Search in sources :

Example 11 with WhileStatementBlock

use of org.apache.sysml.parser.WhileStatementBlock in project incubator-systemml by apache.

Method rFindAndRecompileIndexingHOP of class ProgramRecompiler.

/**
 * Walks a statement block together with its compiled program block and
 * recompiles every last-level block in which the indexing hops on
 * {@code var} had their forced execution type changed.
 *
 * NOTE: if force is set, we set and recompile the respective indexing hops;
 * otherwise, we release the forced exec type and recompile again. Hence,
 * any changes can be exactly reverted with the same access behavior.
 *
 * Control-flow blocks (if, while, for/parfor) first process their predicate
 * or loop-bound hops, then recurse into their bodies, pairing statement
 * blocks with program blocks by position.
 *
 * @param sb statement block
 * @param pb program block compiled from {@code sb}
 * @param var variable whose indexing hops are targeted
 * @param ec execution context; its variables are handed to the recompiler
 * @param force if true, set and recompile the respective indexing hops;
 *              if false, release the forced exec type and recompile
 * @throws DMLRuntimeException wrapping any exception raised while
 *         processing or recompiling a last-level block
 */
public static void rFindAndRecompileIndexingHOP(StatementBlock sb, ProgramBlock pb, String var, ExecutionContext ec, boolean force) {
    if (pb instanceof IfProgramBlock && sb instanceof IfStatementBlock) {
        IfProgramBlock ifProg = (IfProgramBlock) pb;
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) sb.getStatement(0);
        // if condition
        if (ifBlock.getPredicateHops() != null)
            ifProg.setPredicate(rFindAndRecompileIndexingHOP(ifBlock.getPredicateHops(), ifProg.getPredicate(), var, ec, force));
        // if branch
        recompileIndexingInBody(ifStmt.getIfBody(), ifProg.getChildBlocksIfBody(), var, ec, force);
        // else branch (only if the program block has one)
        if (ifProg.getChildBlocksElseBody() != null)
            recompileIndexingInBody(ifStmt.getElseBody(), ifProg.getChildBlocksElseBody(), var, ec, force);
        return;
    }
    if (pb instanceof WhileProgramBlock && sb instanceof WhileStatementBlock) {
        WhileProgramBlock whileProg = (WhileProgramBlock) pb;
        WhileStatementBlock whileBlock = (WhileStatementBlock) sb;
        WhileStatement whileStmt = (WhileStatement) sb.getStatement(0);
        // while condition
        if (whileBlock.getPredicateHops() != null)
            whileProg.setPredicate(rFindAndRecompileIndexingHOP(whileBlock.getPredicateHops(), whileProg.getPredicate(), var, ec, force));
        // loop body
        recompileIndexingInBody(whileStmt.getBody(), whileProg.getChildBlocks(), var, ec, force);
        return;
    }
    if (pb instanceof ForProgramBlock && sb instanceof ForStatementBlock) {
        // covers both for and parfor
        ForProgramBlock forProg = (ForProgramBlock) pb;
        ForStatementBlock forBlock = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        // loop bounds: from, to, increment
        if (forBlock.getFromHops() != null)
            forProg.setFromInstructions(rFindAndRecompileIndexingHOP(forBlock.getFromHops(), forProg.getFromInstructions(), var, ec, force));
        if (forBlock.getToHops() != null)
            forProg.setToInstructions(rFindAndRecompileIndexingHOP(forBlock.getToHops(), forProg.getToInstructions(), var, ec, force));
        if (forBlock.getIncrementHops() != null)
            forProg.setIncrementInstructions(rFindAndRecompileIndexingHOP(forBlock.getIncrementHops(), forProg.getIncrementInstructions(), var, ec, force));
        // loop body
        recompileIndexingInBody(forStmt.getBody(), forProg.getChildBlocks(), var, ec, force);
        return;
    }
    // last-level program block: mark/unmark the indexing hops of this DAG
    // and recompile only if at least one root reported a change
    try {
        Hop.resetVisitStatus(sb.getHops());
        boolean modified = false;
        for (Hop root : sb.getHops()) {
            // non-short-circuit |= so every root is visited
            modified |= force ? rFindAndSetCPIndexingHOP(root, var) : rFindAndReleaseIndexingHOP(root, var);
        }
        if (modified)
            pb.setInstructions(Recompiler.recompileHopsDag(sb, sb.getHops(), ec.getVariables(), null, true, false, 0));
    } catch (Exception ex) {
        throw new DMLRuntimeException(ex);
    }
}

/**
 * Recurses into a control-flow body, pairing the i-th statement block with
 * the i-th program block. Only positions present in both lists are visited,
 * so the two lists may differ in length.
 */
private static void recompileIndexingInBody(java.util.List<StatementBlock> sbs, java.util.List<ProgramBlock> pbs, String var, ExecutionContext ec, boolean force) {
    int numStmtBlocks = sbs.size();
    for (int pos = 0; pos < pbs.size() && pos < numStmtBlocks; pos++) {
        ProgramBlock childProg = pbs.get(pos);
        StatementBlock childBlock = sbs.get(pos);
        rFindAndRecompileIndexingHOP(childBlock, childProg, var, ec, force);
    }
}
Also used : IfProgramBlock(org.apache.sysml.runtime.controlprogram.IfProgramBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) ForProgramBlock(org.apache.sysml.runtime.controlprogram.ForProgramBlock) Hop(org.apache.sysml.hops.Hop) WhileProgramBlock(org.apache.sysml.runtime.controlprogram.WhileProgramBlock) WhileStatement(org.apache.sysml.parser.WhileStatement) BinaryCPInstruction(org.apache.sysml.runtime.instructions.cp.BinaryCPInstruction) FunctionCallCPInstruction(org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction) Instruction(org.apache.sysml.runtime.instructions.Instruction) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) IfStatement(org.apache.sysml.parser.IfStatement) ForProgramBlock(org.apache.sysml.runtime.controlprogram.ForProgramBlock) IfProgramBlock(org.apache.sysml.runtime.controlprogram.IfProgramBlock) ProgramBlock(org.apache.sysml.runtime.controlprogram.ProgramBlock) WhileProgramBlock(org.apache.sysml.runtime.controlprogram.WhileProgramBlock) ForStatement(org.apache.sysml.parser.ForStatement) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Example 12 with WhileStatementBlock

use of org.apache.sysml.parser.WhileStatementBlock in project incubator-systemml by apache.

Method rFlagFunctionForRecompileOnce of class IPAPassFlagFunctionsRecompileOnce.

/**
 * Decides whether a statement block (recursively, for function bodies and
 * if/else branches) would need recompilation when executed inside a loop.
 * Any while or for block is flagged unconditionally, because recompilation
 * information is not yet available at this point.
 *
 * @param sb statement block to inspect
 * @param inLoop true if the block is considered to run inside a loop
 * @return true if the block requires recompilation inside a loop statement block
 */
public boolean rFlagFunctionForRecompileOnce(StatementBlock sb, boolean inLoop) {
    boolean flag = false;
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatement fn = (FunctionStatement) sb.getStatement(0);
        for (StatementBlock child : fn.getBody()) {
            flag |= rFlagFunctionForRecompileOnce(child, inLoop);
        }
        return flag;
    }
    if (sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock) {
        // conservatively mark every loop statement block
        return true;
    }
    if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        flag = inLoop && ifBlock.requiresPredicateRecompilation();
        for (StatementBlock child : ifStmt.getIfBody()) {
            flag |= rFlagFunctionForRecompileOnce(child, inLoop);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            flag |= rFlagFunctionForRecompileOnce(child, inLoop);
        }
        return flag;
    }
    // generic (last-level) block
    return inLoop && sb.requiresRecompilation();
}
Also used : ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) FunctionStatement(org.apache.sysml.parser.FunctionStatement) IfStatement(org.apache.sysml.parser.IfStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Example 13 with WhileStatementBlock

use of org.apache.sysml.parser.WhileStatementBlock in project incubator-systemml by apache.

Method rReplaceLiterals of class IPAPassPropagateReplaceLiterals.

private void rReplaceLiterals(StatementBlock sb, LocalVariableMap constants) {
    // Forget every constant that this block overwrites. The map is mutated in
    // place, so the removal is also visible to later sibling blocks.
    for (String updatedVar : sb.variablesUpdated().getVariableNames()) {
        if (constants.keySet().contains(updatedVar)) {
            constants.remove(updatedVar);
        }
    }
    // Hand the remaining constants to replaceLiterals (defined elsewhere) for
    // each hop DAG of this block, recursing into nested bodies.
    if (sb instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) sb;
        WhileStatement whileStmt = (WhileStatement) sb.getStatement(0);
        replaceLiterals(whileBlock.getPredicateHops(), constants);
        for (StatementBlock child : whileStmt.getBody()) {
            rReplaceLiterals(child, constants);
        }
    } else if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) sb.getStatement(0);
        replaceLiterals(ifBlock.getPredicateHops(), constants);
        for (StatementBlock child : ifStmt.getIfBody()) {
            rReplaceLiterals(child, constants);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            rReplaceLiterals(child, constants);
        }
    } else if (sb instanceof ForStatementBlock) {
        ForStatementBlock forBlock = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) sb.getStatement(0);
        replaceLiterals(forBlock.getFromHops(), constants);
        replaceLiterals(forBlock.getToHops(), constants);
        replaceLiterals(forBlock.getIncrementHops(), constants);
        for (StatementBlock child : forStmt.getBody()) {
            rReplaceLiterals(child, constants);
        }
    } else {
        // last-level block: its own hop DAG
        replaceLiterals(sb.getHops(), constants);
    }
}
Also used : ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) IfStatement(org.apache.sysml.parser.IfStatement) WhileStatement(org.apache.sysml.parser.WhileStatement) ForStatement(org.apache.sysml.parser.ForStatement) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Example 14 with WhileStatementBlock

use of org.apache.sysml.parser.WhileStatementBlock in project incubator-systemml by apache.

Method removeCheckpointBeforeUpdate of class IPAPassRemoveUnnecessaryCheckpoints.

private static void removeCheckpointBeforeUpdate(DMLProgram dmlp) {
    // Approach: scan the top-level program (guaranteed to be unconditional) and
    // collect checkpoints. For each, determine whether its variable is used
    // before being updated. When a second checkpoint of the same variable is
    // found while the first is still a candidate (update in between, not used
    // before the update), the first checkpoint is removed.
    HashMap<String, Hop> pending = new HashMap<>();
    for (StatementBlock block : dmlp.getStatementBlocks()) {
        pruneCheckpointCandsReadBeforeUpdate(block, pending);
        boolean isControlFlow = block instanceof IfStatementBlock
            || block instanceof WhileStatementBlock
            || block instanceof ForStatementBlock;
        if (isControlFlow) {
            pruneCheckpointCandsUpdatedConditionally(block, pending);
        } else {
            pruneCheckpointCandsUpdatedWithMultipleReads(block, pending);
        }
        if (HopRewriteUtils.isLastLevelStatementBlock(block)) {
            registerCheckpointCands(block, pending);
        }
    }
}

/** Drops candidates that this block reads without updating them. */
private static void pruneCheckpointCandsReadBeforeUpdate(StatementBlock sb, HashMap<String, Hop> pending) {
    Set<String> snapshot = new HashSet<>(pending.keySet());
    for (String var : snapshot) {
        boolean readNotUpdated = sb.variablesRead().containsVariable(var)
            && !sb.variablesUpdated().containsVariable(var);
        if (!readNotUpdated) {
            continue;
        }
        // variablesRead may contain false positives, e.g. meta data operations
        // like nrow(X) or operations removed by rewrites. For blocks with hops,
        // double check the DAG; otherwise assume the worst case.
        boolean notActuallyRead = false;
        if (sb.getHops() != null) {
            Hop.resetVisitStatus(sb.getHops());
            notActuallyRead = true;
            for (Hop root : sb.getHops()) {
                // non-short-circuit &= so every root is inspected
                notActuallyRead &= !HopRewriteUtils.rContainsRead(root, var, false);
            }
        }
        if (!notActuallyRead) {
            pending.remove(var);
        }
    }
}

/** Drops candidates updated inside conditional control flow (if/while/for). */
private static void pruneCheckpointCandsUpdatedConditionally(StatementBlock sb, HashMap<String, Hop> pending) {
    Set<String> snapshot = new HashSet<>(pending.keySet());
    for (String var : snapshot) {
        if (sb.variablesUpdated().containsVariable(var)) {
            pending.remove(var);
        }
    }
}

/** Drops candidates updated in a basic block whose update is not a simple read chain. */
private static void pruneCheckpointCandsUpdatedWithMultipleReads(StatementBlock sb, HashMap<String, Hop> pending) {
    Set<String> snapshot = new HashSet<>(pending.keySet());
    for (String var : snapshot) {
        if (!sb.variablesUpdated().containsVariable(var) || sb.getHops() == null) {
            continue;
        }
        Hop.resetVisitStatus(sb.getHops());
        for (Hop root : sb.getHops()) {
            if (root.getName().equals(var) && !HopRewriteUtils.rHasSimpleReadChain(root, var)) {
                pending.remove(var);
            }
        }
    }
}

/** Registers this block's checkpoints, disabling any earlier candidate for the same variable. */
private static void registerCheckpointCands(StatementBlock sb, HashMap<String, Hop> pending) {
    for (Hop checkpoint : collectCheckpoints(sb.getHops())) {
        String name = checkpoint.getName();
        if (pending.containsKey(name)) {
            pending.get(name).setRequiresCheckpoint(false);
        }
        pending.put(name, checkpoint);
    }
}
Also used : ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) HashMap(java.util.HashMap) Hop(org.apache.sysml.hops.Hop) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) HashSet(java.util.HashSet) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Example 15 with WhileStatementBlock

use of org.apache.sysml.parser.WhileStatementBlock in project incubator-systemml by apache.

Method rRecompileProgramBlock2Forced of class Recompiler.

/**
 * Recursively recompiles a program block with a forced execution type.
 * While and if blocks recompile their predicate, and for/parfor blocks their
 * loop-bound instructions; both then recurse into their child blocks.
 * Function program blocks recurse into their bodies. Last-level blocks
 * recompile their hop DAG and then descend into every function they call.
 *
 * @param pb      program block to recompile
 * @param tid     id passed through unchanged to
 *                {@code Recompiler.recompileHopsDag2Forced}
 * @param fnStack keys of functions already visited in this traversal; it
 *                prevents repeated recompilation of functions called more
 *                than once and endless descent on recursive calls
 * @param et      execution type to force
 */
private static void rRecompileProgramBlock2Forced(ProgramBlock pb, long tid, HashSet<String> fnStack, ExecType et) {
    if (pb instanceof WhileProgramBlock) {
        WhileProgramBlock pbTmp = (WhileProgramBlock) pb;
        WhileStatementBlock sbTmp = (WhileStatementBlock) pbTmp.getStatementBlock();
        // recompile predicate (null-hops guard consistent with the for branch)
        if (sbTmp != null && sbTmp.getPredicateHops() != null && needsForcedRecompile(pbTmp.getPredicate(), et))
            pbTmp.setPredicate(Recompiler.recompileHopsDag2Forced(sbTmp.getPredicateHops(), tid, et));
        // recompile body
        for (ProgramBlock pb2 : pbTmp.getChildBlocks())
            rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
    } else if (pb instanceof IfProgramBlock) {
        IfProgramBlock pbTmp = (IfProgramBlock) pb;
        IfStatementBlock sbTmp = (IfStatementBlock) pbTmp.getStatementBlock();
        // recompile predicate
        if (sbTmp != null && sbTmp.getPredicateHops() != null && needsForcedRecompile(pbTmp.getPredicate(), et))
            pbTmp.setPredicate(Recompiler.recompileHopsDag2Forced(sbTmp.getPredicateHops(), tid, et));
        // recompile if body
        for (ProgramBlock pb2 : pbTmp.getChildBlocksIfBody())
            rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
        // recompile else body; guard against a missing else list, as other
        // traversals over IfProgramBlock do
        if (pbTmp.getChildBlocksElseBody() != null) {
            for (ProgramBlock pb2 : pbTmp.getChildBlocksElseBody())
                rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
        }
    } else if (pb instanceof ForProgramBlock) { // includes ParForProgramBlock
        ForProgramBlock pbTmp = (ForProgramBlock) pb;
        ForStatementBlock sbTmp = (ForStatementBlock) pbTmp.getStatementBlock();
        // recompile loop bounds (from, to, increment)
        if (sbTmp != null) {
            if (sbTmp.getFromHops() != null && needsForcedRecompile(pbTmp.getFromInstructions(), et))
                pbTmp.setFromInstructions(Recompiler.recompileHopsDag2Forced(sbTmp.getFromHops(), tid, et));
            if (sbTmp.getToHops() != null && needsForcedRecompile(pbTmp.getToInstructions(), et))
                pbTmp.setToInstructions(Recompiler.recompileHopsDag2Forced(sbTmp.getToHops(), tid, et));
            if (sbTmp.getIncrementHops() != null && needsForcedRecompile(pbTmp.getIncrementInstructions(), et))
                pbTmp.setIncrementInstructions(Recompiler.recompileHopsDag2Forced(sbTmp.getIncrementHops(), tid, et));
        }
        // recompile body
        for (ProgramBlock pb2 : pbTmp.getChildBlocks())
            rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
    } else if (pb instanceof FunctionProgramBlock) {
        // includes ExternalFunctionProgramBlock and ExternalFunctionProgramBlockCP
        FunctionProgramBlock tmp = (FunctionProgramBlock) pb;
        for (ProgramBlock pb2 : tmp.getChildBlocks())
            rRecompileProgramBlock2Forced(pb2, tid, fnStack, et);
    } else {
        StatementBlock sb = pb.getStatementBlock();
        // recompile this last-level block's hop DAG with the forced exec type
        // NOTE(review): the original explanatory comment here was truncated;
        // only its tail survived: "(would be invalid with permutation matrix
        // mult across multiple dags)". The full rationale needs confirming.
        if (sb != null)
            pb.setInstructions(Recompiler.recompileHopsDag2Forced(sb, sb.getHops(), tid, et));
        // recompile functions called from this block
        if (OptTreeConverter.containsFunctionCallInstruction(pb)) {
            for (Instruction inst : pb.getInstructions()) {
                if (!(inst instanceof FunctionCallCPInstruction))
                    continue;
                FunctionCallCPInstruction func = (FunctionCallCPInstruction) inst;
                String fname = func.getFunctionName();
                String fnamespace = func.getNamespace();
                String fKey = DMLProgram.constructFunctionKey(fnamespace, fname);
                // memoization for multiple calls and recursion:
                // HashSet.add returns false if fKey was already visited
                if (fnStack.add(fKey)) {
                    FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(fnamespace, fname);
                    // recompile chains of functions
                    rRecompileProgramBlock2Forced(fpb, tid, fnStack, et);
                }
            }
        }
    }
}

/**
 * Returns true unless {@code et} is CP and the given instructions contain no
 * MR-job instruction (as reported by
 * {@code OptTreeConverter.containsMRJobInstruction(insts, true, true)}).
 * Presumably such instructions are already CP-only and need no forced
 * recompile; confirm against callers.
 */
private static boolean needsForcedRecompile(ArrayList<Instruction> insts, ExecType et) {
    return et != ExecType.CP || OptTreeConverter.containsMRJobInstruction(insts, true, true);
}
Also used : IfProgramBlock(org.apache.sysml.runtime.controlprogram.IfProgramBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) FunctionProgramBlock(org.apache.sysml.runtime.controlprogram.FunctionProgramBlock) ForProgramBlock(org.apache.sysml.runtime.controlprogram.ForProgramBlock) ParForProgramBlock(org.apache.sysml.runtime.controlprogram.ParForProgramBlock) ArrayList(java.util.ArrayList) WhileProgramBlock(org.apache.sysml.runtime.controlprogram.WhileProgramBlock) MRJobInstruction(org.apache.sysml.runtime.instructions.MRJobInstruction) Instruction(org.apache.sysml.runtime.instructions.Instruction) FunctionCallCPInstruction(org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction) SeqInstruction(org.apache.sysml.runtime.instructions.mr.SeqInstruction) RandInstruction(org.apache.sysml.runtime.instructions.mr.RandInstruction) FunctionCallCPInstruction(org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction) FunctionProgramBlock(org.apache.sysml.runtime.controlprogram.FunctionProgramBlock) WhileProgramBlock(org.apache.sysml.runtime.controlprogram.WhileProgramBlock) ForProgramBlock(org.apache.sysml.runtime.controlprogram.ForProgramBlock) IfProgramBlock(org.apache.sysml.runtime.controlprogram.IfProgramBlock) ProgramBlock(org.apache.sysml.runtime.controlprogram.ProgramBlock) ParForProgramBlock(org.apache.sysml.runtime.controlprogram.ParForProgramBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Aggregations

WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock)41 ForStatementBlock (org.apache.sysml.parser.ForStatementBlock)38 IfStatementBlock (org.apache.sysml.parser.IfStatementBlock)36 StatementBlock (org.apache.sysml.parser.StatementBlock)36 ForStatement (org.apache.sysml.parser.ForStatement)21 IfStatement (org.apache.sysml.parser.IfStatement)21 FunctionStatementBlock (org.apache.sysml.parser.FunctionStatementBlock)20 WhileStatement (org.apache.sysml.parser.WhileStatement)20 Hop (org.apache.sysml.hops.Hop)17 ArrayList (java.util.ArrayList)15 FunctionStatement (org.apache.sysml.parser.FunctionStatement)14 WhileProgramBlock (org.apache.sysml.runtime.controlprogram.WhileProgramBlock)14 ForProgramBlock (org.apache.sysml.runtime.controlprogram.ForProgramBlock)13 IfProgramBlock (org.apache.sysml.runtime.controlprogram.IfProgramBlock)13 FunctionProgramBlock (org.apache.sysml.runtime.controlprogram.FunctionProgramBlock)9 ProgramBlock (org.apache.sysml.runtime.controlprogram.ProgramBlock)8 ParForStatementBlock (org.apache.sysml.parser.ParForStatementBlock)7 LocalVariableMap (org.apache.sysml.runtime.controlprogram.LocalVariableMap)7 ExternalFunctionStatement (org.apache.sysml.parser.ExternalFunctionStatement)6 HashMap (java.util.HashMap)5