Search in sources:

Example 6 with ForStatement

Use of org.apache.sysml.parser.ForStatement in project incubator-systemml by Apache.

From class ProgramRewriter, method rewriteStatementBlock.

/**
 * Rewrites a single statement block: first recursively rewrites the statement blocks nested
 * in function, while, if (if and else body), and for/parfor bodies, then applies every rule
 * of {@code _sbRuleSet} in order, each rule consuming the output of the previous one.
 *
 * @param sb     statement block to rewrite
 * @param status rewrite status passed to nested rewrites and rules; its parfor-context flag
 *               is set while descending into a parfor body and restored afterwards
 * @return the statement blocks replacing {@code sb}; since each rule returns a list per input
 *         block, the result may contain more or fewer blocks than the single input
 * @throws HopsException if a nested rewrite or a rewrite rule fails
 */
private ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status) throws HopsException {
    ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>();
    ret.add(sb);
    //recursive invocation
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
        FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
        fstmt.setBody(rewriteStatementBlocks(fstmt.getBody(), status));
    } else if (sb instanceof WhileStatementBlock) {
        WhileStatementBlock wsb = (WhileStatementBlock) sb;
        WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
        wstmt.setBody(rewriteStatementBlocks(wstmt.getBody(), status));
    } else if (sb instanceof IfStatementBlock) {
        IfStatementBlock isb = (IfStatementBlock) sb;
        IfStatement istmt = (IfStatement) isb.getStatement(0);
        istmt.setIfBody(rewriteStatementBlocks(istmt.getIfBody(), status));
        istmt.setElseBody(rewriteStatementBlocks(istmt.getElseBody(), status));
    } else if (sb instanceof ForStatementBlock) { //incl parfor
        //maintain parfor context information (e.g., for checkpointing); the previous
        //value is restored in a finally block so that a failing nested rewrite cannot
        //leave the parfor flag set for subsequently processed, non-parfor blocks
        boolean prestatus = status.isInParforContext();
        if (sb instanceof ParForStatementBlock)
            status.setInParforContext(true);
        try {
            ForStatementBlock fsb = (ForStatementBlock) sb;
            ForStatement fstmt = (ForStatement) fsb.getStatement(0);
            fstmt.setBody(rewriteStatementBlocks(fstmt.getBody(), status));
        } finally {
            status.setInParforContext(prestatus);
        }
    }
    //apply rewrite rules: each rule is applied to every block produced by the previous rule
    for (StatementBlockRewriteRule r : _sbRuleSet) {
        ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>();
        for (StatementBlock sbc : ret)
            tmp.addAll(r.rewriteStatementBlock(sbc, status));
        //take over set of rewritten sbs (ret is local, so no copy is needed)
        ret = tmp;
    }
    return ret;
}
Also used: ParForStatementBlock (org.apache.sysml.parser.ParForStatementBlock), ForStatementBlock (org.apache.sysml.parser.ForStatementBlock), FunctionStatementBlock (org.apache.sysml.parser.FunctionStatementBlock), ArrayList (java.util.ArrayList), WhileStatement (org.apache.sysml.parser.WhileStatement), FunctionStatement (org.apache.sysml.parser.FunctionStatement), IfStatement (org.apache.sysml.parser.IfStatement), ForStatement (org.apache.sysml.parser.ForStatement), IfStatementBlock (org.apache.sysml.parser.IfStatementBlock), WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock), StatementBlock (org.apache.sysml.parser.StatementBlock)

Example 7 with ForStatement

Use of org.apache.sysml.parser.ForStatement in project incubator-systemml by Apache.

From class InterProceduralAnalysis, method getFunctionCandidatesForStatisticPropagation.

/////////////////////////////
// GET FUNCTION CANDIDATES
//////

/**
 * Walks a statement block recursively in search of function-call candidates for statistics
 * propagation. Function, while, if (both branches), and for/parfor blocks are descended into;
 * for any other block, each root hop of its DAG is passed, together with the block's DML
 * program, to the hop-level overload, which fills {@code fcandCounts} and {@code fcandHops}.
 *
 * @param sb          statement block to inspect
 * @param fcandCounts candidate counts, updated by the hop-level overload
 * @param fcandHops   candidate function ops, updated by the hop-level overload
 * @throws HopsException  propagated from the hop-level overload
 * @throws ParseException propagated from the hop-level overload
 */
private void getFunctionCandidatesForStatisticPropagation(StatementBlock sb, Map<String, Integer> fcandCounts, Map<String, FunctionOp> fcandHops) throws HopsException, ParseException {
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatement funcStmt = (FunctionStatement) sb.getStatement(0);
        for (StatementBlock child : funcStmt.getBody())
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        return;
    }
    if (sb instanceof WhileStatementBlock) {
        WhileStatement whileStmt = (WhileStatement) sb.getStatement(0);
        for (StatementBlock child : whileStmt.getBody())
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        return;
    }
    if (sb instanceof IfStatementBlock) {
        IfStatement ifStmt = (IfStatement) sb.getStatement(0);
        for (StatementBlock child : ifStmt.getIfBody())
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        for (StatementBlock child : ifStmt.getElseBody())
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        return;
    }
    if (sb instanceof ForStatementBlock) { //for and parfor alike
        ForStatement forStmt = (ForStatement) sb.getStatement(0);
        for (StatementBlock child : forStmt.getBody())
            getFunctionCandidatesForStatisticPropagation(child, fcandCounts, fcandHops);
        return;
    }
    //last-level block: inspect its DAG roots (empty statement blocks have none)
    ArrayList<Hop> rootHops = sb.get_hops();
    if (rootHops == null)
        return;
    for (Hop rootHop : rootHops)
        getFunctionCandidatesForStatisticPropagation(sb.getDMLProg(), rootHop, fcandCounts, fcandHops);
}
Also used: ForStatementBlock (org.apache.sysml.parser.ForStatementBlock), ExternalFunctionStatement (org.apache.sysml.parser.ExternalFunctionStatement), FunctionStatement (org.apache.sysml.parser.FunctionStatement), IfStatement (org.apache.sysml.parser.IfStatement), FunctionStatementBlock (org.apache.sysml.parser.FunctionStatementBlock), ArrayList (java.util.ArrayList), Hop (org.apache.sysml.hops.Hop), WhileStatement (org.apache.sysml.parser.WhileStatement), ForStatement (org.apache.sysml.parser.ForStatement), IfStatementBlock (org.apache.sysml.parser.IfStatementBlock), WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock), StatementBlock (org.apache.sysml.parser.StatementBlock)

Example 8 with ForStatement

Use of org.apache.sysml.parser.ForStatement in project incubator-systemml by Apache.

From class InterProceduralAnalysis, method rRemoveConstantBinaryOp.

/**
 * Recursively applies the hop-level {@code rRemoveConstantBinaryOp} to every hop DAG reachable
 * from the given statement block. If, while, and for/parfor blocks are descended into (the
 * else body only when present); for any other block, the visit status of its root hops is
 * reset before each root is processed.
 *
 * @param sb    statement block to process
 * @param mOnes map passed unchanged to the hop-level overload
 * @throws HopsException propagated from the hop-level overload
 */
private void rRemoveConstantBinaryOp(StatementBlock sb, HashMap<String, Hop> mOnes) throws HopsException {
    if (sb instanceof IfStatementBlock) {
        IfStatement ifStmt = (IfStatement) sb.getStatement(0);
        for (StatementBlock nested : ifStmt.getIfBody())
            rRemoveConstantBinaryOp(nested, mOnes);
        //guard against a missing else body
        if (ifStmt.getElseBody() != null) {
            for (StatementBlock nested : ifStmt.getElseBody())
                rRemoveConstantBinaryOp(nested, mOnes);
        }
        return;
    }
    if (sb instanceof WhileStatementBlock) {
        WhileStatement whileStmt = (WhileStatement) sb.getStatement(0);
        for (StatementBlock nested : whileStmt.getBody())
            rRemoveConstantBinaryOp(nested, mOnes);
        return;
    }
    if (sb instanceof ForStatementBlock) {
        ForStatement forStmt = (ForStatement) sb.getStatement(0);
        for (StatementBlock nested : forStmt.getBody())
            rRemoveConstantBinaryOp(nested, mOnes);
        return;
    }
    //generic block: process its hop DAG, if it has one
    ArrayList<Hop> roots = sb.get_hops();
    if (roots == null)
        return;
    Hop.resetVisitStatus(roots);
    for (Hop root : roots)
        rRemoveConstantBinaryOp(root, mOnes);
}
Also used: ForStatementBlock (org.apache.sysml.parser.ForStatementBlock), IfStatement (org.apache.sysml.parser.IfStatement), Hop (org.apache.sysml.hops.Hop), WhileStatement (org.apache.sysml.parser.WhileStatement), ForStatement (org.apache.sysml.parser.ForStatement), FunctionStatementBlock (org.apache.sysml.parser.FunctionStatementBlock), IfStatementBlock (org.apache.sysml.parser.IfStatementBlock), WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock), StatementBlock (org.apache.sysml.parser.StatementBlock)

Example 9 with ForStatement

Use of org.apache.sysml.parser.ForStatement in project incubator-systemml by Apache.

From class ProgramRecompiler, method replaceConstantScalarVariables.

/**
 * Recursively replaces references to the scalar variables in {@code vars} with literals.
 * For if and while blocks the predicate hops are processed, for for/parfor blocks the
 * from, to, and increment hops; nested bodies are then visited. For any other block, the
 * visit status of its root hops is reset and each root is handed to
 * {@code Recompiler.rReplaceLiterals}.
 *
 * @param sb   statement block to process
 * @param vars variables whose values are substituted
 * @throws DMLRuntimeException propagated from the literal replacement helpers
 * @throws HopsException       propagated from the literal replacement helpers
 */
public static void replaceConstantScalarVariables(StatementBlock sb, LocalVariableMap vars) throws DMLRuntimeException, HopsException {
    if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) sb.getStatement(0);
        replacePredicateLiterals(ifBlock.getPredicateHops(), vars);
        for (StatementBlock child : ifStmt.getIfBody())
            replaceConstantScalarVariables(child, vars);
        for (StatementBlock child : ifStmt.getElseBody())
            replaceConstantScalarVariables(child, vars);
        return;
    }
    if (sb instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) sb;
        WhileStatement whileStmt = (WhileStatement) sb.getStatement(0);
        replacePredicateLiterals(whileBlock.getPredicateHops(), vars);
        for (StatementBlock child : whileStmt.getBody())
            replaceConstantScalarVariables(child, vars);
        return;
    }
    if (sb instanceof ForStatementBlock) { //covers both for and parfor
        ForStatementBlock forBlock = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        replacePredicateLiterals(forBlock.getFromHops(), vars);
        replacePredicateLiterals(forBlock.getToHops(), vars);
        replacePredicateLiterals(forBlock.getIncrementHops(), vars);
        for (StatementBlock child : forStmt.getBody())
            replaceConstantScalarVariables(child, vars);
        return;
    }
    //last-level block: replace constant literals in its hop DAG, if any
    ArrayList<Hop> roots = sb.get_hops();
    if (roots == null)
        return;
    Hop.resetVisitStatus(roots);
    for (Hop root : roots)
        Recompiler.rReplaceLiterals(root, vars, true);
}
Also used: ForStatementBlock (org.apache.sysml.parser.ForStatementBlock), IfStatement (org.apache.sysml.parser.IfStatement), Hop (org.apache.sysml.hops.Hop), WhileStatement (org.apache.sysml.parser.WhileStatement), ForStatement (org.apache.sysml.parser.ForStatement), IfStatementBlock (org.apache.sysml.parser.IfStatementBlock), WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock), StatementBlock (org.apache.sysml.parser.StatementBlock)

Example 10 with ForStatement

Use of org.apache.sysml.parser.ForStatement in project incubator-systemml by Apache.

From class OptimizationWrapper, method optimize.

/**
 * Optimizes a parfor statement/program block with the optimizer selected by {@code otype}.
 * <p>
 * If dynamic recompilation is enabled, the parfor body is prepared first:
 * <ul>
 * <li>constant scalar variables are replaced by literals;</li>
 * <li>program rewrites are applied, and the child program blocks are regenerated if branches were removed;</li>
 * <li>the body is recompiled, along with any functions reported by inter-procedural analysis.</li>
 * </ul>
 * Afterwards an opt tree is created and optimized, then optionally checked for correctness.
 * Optimizer statistics are recorded if {@code monitor} is set.
 *
 * @param otype   optimizer mode, passed to {@code createOptimizer}
 * @param ck      passed through to {@code OptTreeConverter.createOptTree}
 *                (presumably a parallelism constraint; confirm)
 * @param cm      passed through to {@code OptTreeConverter.createOptTree}
 *                (presumably a memory constraint; confirm)
 * @param sb      parfor statement block
 * @param pb      parfor program block
 * @param ec      execution context supplying the current variables
 * @param monitor whether to record optimizer statistics in the {@code StatisticMonitor}
 * @throws DMLRuntimeException if any preparation, opt-tree construction, optimization, or
 *                             correctness-check step fails
 */
@SuppressWarnings("unused")
private static void optimize(POptMode otype, int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean monitor) throws DMLRuntimeException {
    Timing time = new Timing(true);
    //maintain statistics
    if (DMLScript.STATISTICS)
        Statistics.incrementParForOptimCount();
    //create specified optimizer
    Optimizer opt = createOptimizer(otype);
    CostModelType cmtype = opt.getCostModelType();
    LOG.trace("ParFOR Opt: Created optimizer (" + otype + "," + opt.getPlanInputType() + "," + cmtype + ")");
    OptTree tree = null;
    try {
        //recompile parfor body
        if (ConfigurationManager.isDynamicRecompilation()) {
            ForStatement fs = (ForStatement) sb.getStatement(0);
            //debug output before recompilation
            if (LOG.isDebugEnabled()) {
                try {
                    tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
                    LOG.debug("ParFOR Opt: Input plan (before recompilation):\n" + tree.explain(false));
                    OptTreeConverter.clear();
                } catch (Exception ex) {
                    throw new DMLRuntimeException("Unable to create opt tree.", ex);
                }
            }
            //separate propagation required because recompile in-place without literal replacement
            try {
                LocalVariableMap constVars = ProgramRecompiler.getReusableScalarVariables(sb.getDMLProg(), sb, ec.getVariables());
                ProgramRecompiler.replaceConstantScalarVariables(sb, constVars);
            } catch (Exception ex) {
                throw new DMLRuntimeException(ex);
            }
            //program rewrites (e.g., constant folding, branch removal) according to replaced literals
            try {
                ProgramRewriter rewriter = createProgramRewriterWithRuleSets();
                ProgramRewriteStatus state = new ProgramRewriteStatus();
                rewriter.rewriteStatementBlockHopDAGs(sb, state);
                fs.setBody(rewriter.rewriteStatementBlocks(fs.getBody(), state));
                if (state.getRemovedBranches()) {
                    LOG.debug("ParFOR Opt: Removed branches during program rewrites, rebuilding runtime program");
                    pb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(pb.getProgram(), fs.getBody()));
                }
            } catch (Exception ex) {
                throw new DMLRuntimeException(ex);
            }
            //recompilation of parfor body and called functions (if safe)
            try {
                //core parfor body recompilation (based on symbol table entries)
                //* clone of variables in order to allow for statistics propagation across DAGs
                //(tid=0, because deep copies created after opt)
                LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone();
                Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, true);
                //inter-procedural optimization (based on previous recompilation)
                if (pb.hasFunctions()) {
                    InterProceduralAnalysis ipa = new InterProceduralAnalysis();
                    Set<String> fcand = ipa.analyzeSubProgram(sb);
                    if (!fcand.isEmpty()) {
                        //regenerate runtime program of modified functions
                        for (String func : fcand) {
                            String[] funcparts = DMLProgram.splitFunctionKey(func);
                            FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(funcparts[0], funcparts[1]);
                            //reset recompilation flags according to recompileOnce because it is only safe if function is recompileOnce
                            //because then recompiled for every execution (otherwise potential issues if func also called outside parfor)
                            Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new LocalVariableMap(), 0, fpb.isRecompileOnce());
                        }
                    }
                }
            } catch (Exception ex) {
                throw new DMLRuntimeException(ex);
            }
        }
        //create opt tree (before optimization)
        try {
            tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
            //guarded: explain() renders the whole plan, which is wasted work if debug is off
            if (LOG.isDebugEnabled())
                LOG.debug("ParFOR Opt: Input plan (before optimization):\n" + tree.explain(false));
        } catch (Exception ex) {
            throw new DMLRuntimeException("Unable to create opt tree.", ex);
        }
        //create cost estimator
        CostEstimator est = createCostEstimator(cmtype, ec.getVariables());
        LOG.trace("ParFOR Opt: Created cost estimator (" + cmtype + ")");
        //core optimize
        opt.optimize(sb, pb, tree, est, ec);
        if (LOG.isDebugEnabled())
            LOG.debug("ParFOR Opt: Optimized plan (after optimization): \n" + tree.explain(false));
        //assert plan correctness
        if (CHECK_PLAN_CORRECTNESS && LOG.isDebugEnabled()) {
            try {
                OptTreePlanChecker.checkProgramCorrectness(pb, sb, new HashSet<String>());
                LOG.debug("ParFOR Opt: Checked plan and program correctness.");
            } catch (Exception ex) {
                throw new DMLRuntimeException("Failed to check program correctness.", ex);
            }
        }
        //timing is taken before cleanup, as before
        long ltime = (long) time.stop();
        LOG.trace("ParFOR Opt: Optimized plan in " + ltime + "ms.");
        if (DMLScript.STATISTICS)
            Statistics.incrementParForOptimTime(ltime);
    } finally {
        //cleanup phase: OptTreeConverter.clear() is a static call, so it is also
        //executed on failure to avoid leaving class-level opt-tree state behind
        OptTreeConverter.clear();
    }
    //monitor stats
    if (monitor) {
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_OPTIMIZER, otype.ordinal());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMTPLANS, opt.getNumTotalPlans());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMEPLANS, opt.getNumEvaluatedPlans());
    }
}
Also used: FunctionProgramBlock (org.apache.sysml.runtime.controlprogram.FunctionProgramBlock), ProgramRewriter (org.apache.sysml.hops.rewrite.ProgramRewriter), DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException), LocalVariableMap (org.apache.sysml.runtime.controlprogram.LocalVariableMap), CostModelType (org.apache.sysml.runtime.controlprogram.parfor.opt.Optimizer.CostModelType), InterProceduralAnalysis (org.apache.sysml.hops.ipa.InterProceduralAnalysis), Timing (org.apache.sysml.runtime.controlprogram.parfor.stat.Timing), ForStatement (org.apache.sysml.parser.ForStatement), ProgramRewriteStatus (org.apache.sysml.hops.rewrite.ProgramRewriteStatus)

Aggregations

ForStatement (org.apache.sysml.parser.ForStatement): 17; ForStatementBlock (org.apache.sysml.parser.ForStatementBlock): 14; IfStatementBlock (org.apache.sysml.parser.IfStatementBlock): 14; StatementBlock (org.apache.sysml.parser.StatementBlock): 14; WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock): 14; WhileStatement (org.apache.sysml.parser.WhileStatement): 13; IfStatement (org.apache.sysml.parser.IfStatement): 12; FunctionStatementBlock (org.apache.sysml.parser.FunctionStatementBlock): 9; FunctionStatement (org.apache.sysml.parser.FunctionStatement): 7; ArrayList (java.util.ArrayList): 6; Hop (org.apache.sysml.hops.Hop): 6; ParForStatementBlock (org.apache.sysml.parser.ParForStatementBlock): 4; ExternalFunctionStatement (org.apache.sysml.parser.ExternalFunctionStatement): 3; ParForStatement (org.apache.sysml.parser.ParForStatement): 3; ForProgramBlock (org.apache.sysml.runtime.controlprogram.ForProgramBlock): 3; FunctionProgramBlock (org.apache.sysml.runtime.controlprogram.FunctionProgramBlock): 3; IfProgramBlock (org.apache.sysml.runtime.controlprogram.IfProgramBlock): 3; LocalVariableMap (org.apache.sysml.runtime.controlprogram.LocalVariableMap): 3; ProgramBlock (org.apache.sysml.runtime.controlprogram.ProgramBlock): 3; WhileProgramBlock (org.apache.sysml.runtime.controlprogram.WhileProgramBlock): 3