Search in sources:

Example 1 with ProgramRewriteStatus

Use of org.apache.sysml.hops.rewrite.ProgramRewriteStatus in project incubator-systemml by Apache.

In class OptimizerRuleBased, method rewriteInjectSparkLoopCheckpointing:

///////
//REWRITE inject spark loop checkpointing
///
/**
 * Injects Spark loop checkpoints into the parfor mapped to the given opt node.
 * <p>
 * Runs the {@code RewriteInjectSparkLoopCheckpointing} hop rewrite (without context
 * awareness) over the parfor statement block and its body. If the rewrite reports
 * injected checkpoints, the child program blocks of the parfor program block are
 * regenerated from the rewritten body.
 *
 * @param n opt node whose ID is mapped to the parfor statement block and program block
 * @throws DMLRuntimeException if the rewrite or runtime program generation fails
 */
protected void rewriteInjectSparkLoopCheckpointing(OptNode n) throws DMLRuntimeException {
    //resolve the statement block (index 0) and program block (index 1) mapped to this node
    Object[] mapped = OptTreeConverter.getAbstractPlanMapping().getMappedProg(n.getID());
    ParForStatementBlock parforSb = (ParForStatementBlock) mapped[0];
    ParForStatement parforStmt = (ParForStatement) parforSb.getStatement(0);
    ParForProgramBlock parforPb = (ParForProgramBlock) mapped[1];

    boolean injected;
    try {
        //checkpoint injection rewrite without context awareness
        ProgramRewriter rewriter = new ProgramRewriter(new RewriteInjectSparkLoopCheckpointing(false));
        ProgramRewriteStatus status = new ProgramRewriteStatus();
        rewriter.rewriteStatementBlockHopDAGs(parforSb, status);
        parforStmt.setBody(rewriter.rewriteStatementBlocks(parforStmt.getBody(), status));

        //newly introduced checkpoints require regenerating the runtime program of the body
        injected = status.getInjectedCheckpoints();
        if (injected) {
            parforPb.setChildBlocks(
                ProgramRecompiler.generatePartitialRuntimeProgram(parforPb.getProgram(), parforStmt.getBody()));
        }
    } catch (Exception ex) {
        throw new DMLRuntimeException(ex);
    }
    LOG.debug(getOptMode() + " OPT: rewrite 'inject spark loop checkpointing' - result=" + injected);
}
Also used: ProgramRewriter(org.apache.sysml.hops.rewrite.ProgramRewriter) ParForStatementBlock(org.apache.sysml.parser.ParForStatementBlock) MatrixObject(org.apache.sysml.runtime.controlprogram.caching.MatrixObject) RDDObject(org.apache.sysml.runtime.instructions.spark.data.RDDObject) ParForStatement(org.apache.sysml.parser.ParForStatement) ProgramRewriteStatus(org.apache.sysml.hops.rewrite.ProgramRewriteStatus) RewriteInjectSparkLoopCheckpointing(org.apache.sysml.hops.rewrite.RewriteInjectSparkLoopCheckpointing) HopsException(org.apache.sysml.hops.HopsException) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) LopsException(org.apache.sysml.lops.LopsException) LanguageException(org.apache.sysml.parser.LanguageException) IOException(java.io.IOException) ParForProgramBlock(org.apache.sysml.runtime.controlprogram.ParForProgramBlock)

Example 2 with ProgramRewriteStatus

Use of org.apache.sysml.hops.rewrite.ProgramRewriteStatus in project incubator-systemml by Apache.

In class OptimizationWrapper, method optimize:

/**
 * Optimizes a parfor statement block / program block pair with the requested optimizer.
 * <p>
 * If dynamic recompilation is enabled, the parfor body is first prepared:
 * <ol>
 * <li>reusable scalar variables are replaced as constants;</li>
 * <li>program rewrites are applied, and the runtime program is rebuilt if branches were removed;</li>
 * <li>the body is recompiled;</li>
 * <li>functions identified by inter-procedural analysis are recompiled.</li>
 * </ol>
 * Afterwards an opt tree is created, optimized with a cost estimator for the optimizer's
 * cost model type, optionally checked for correctness, and statistics are maintained.
 *
 * @param otype   optimizer type to create
 * @param ck      passed through to {@code OptTreeConverter.createOptTree} (presumably a
 *                degree-of-parallelism constraint — TODO confirm)
 * @param cm      passed through to {@code OptTreeConverter.createOptTree} (presumably a
 *                memory constraint — TODO confirm)
 * @param sb      parfor statement block
 * @param pb      parfor program block
 * @param ec      execution context providing the current variables
 * @param monitor if true, optimizer statistics are recorded in the {@code StatisticMonitor}
 * @throws DMLRuntimeException if rewriting, recompilation, opt tree creation,
 *                             optimization or the plan check fails
 */
@SuppressWarnings("unused")
private static void optimize(POptMode otype, int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean monitor) throws DMLRuntimeException {
    Timing time = new Timing(true);
    //maintain statistics
    if (DMLScript.STATISTICS)
        Statistics.incrementParForOptimCount();
    //create specified optimizer
    Optimizer opt = createOptimizer(otype);
    CostModelType cmtype = opt.getCostModelType();
    LOG.trace("ParFOR Opt: Created optimizer (" + otype + "," + opt.getPlanInputType() + "," + opt.getCostModelType() + ")");
    OptTree tree = null;
    //recompile parfor body
    if (ConfigurationManager.isDynamicRecompilation()) {
        ForStatement fs = (ForStatement) sb.getStatement(0);
        //debug output before recompilation
        if (LOG.isDebugEnabled()) {
            try {
                tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
                LOG.debug("ParFOR Opt: Input plan (before recompilation):\n" + tree.explain(false));
            } catch (Exception ex) {
                throw new DMLRuntimeException("Unable to create opt tree.", ex);
            } finally {
                //release the mapping of this temporary debug tree, also on failure
                OptTreeConverter.clear();
            }
        }
        //separate propagation required because recompile in-place without literal replacement)
        try {
            LocalVariableMap constVars = ProgramRecompiler.getReusableScalarVariables(sb.getDMLProg(), sb, ec.getVariables());
            ProgramRecompiler.replaceConstantScalarVariables(sb, constVars);
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        //program rewrites (e.g., constant folding, branch removal) according to replaced literals
        try {
            ProgramRewriter rewriter = createProgramRewriterWithRuleSets();
            ProgramRewriteStatus state = new ProgramRewriteStatus();
            rewriter.rewriteStatementBlockHopDAGs(sb, state);
            fs.setBody(rewriter.rewriteStatementBlocks(fs.getBody(), state));
            if (state.getRemovedBranches()) {
                LOG.debug("ParFOR Opt: Removed branches during program rewrites, rebuilding runtime program");
                pb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(pb.getProgram(), fs.getBody()));
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        //recompilation of parfor body and called functions (if safe)
        try {
            //core parfor body recompilation (based on symbol table entries)
            //* clone of variables in order to allow for statistics propagation across DAGs
            //(tid=0, because deep copies created after opt)
            LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone();
            Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, true);
            //inter-procedural optimization (based on previous recompilation)
            if (pb.hasFunctions()) {
                InterProceduralAnalysis ipa = new InterProceduralAnalysis();
                Set<String> fcand = ipa.analyzeSubProgram(sb);
                if (!fcand.isEmpty()) {
                    //regenerate runtime program of modified functions
                    for (String func : fcand) {
                        String[] funcparts = DMLProgram.splitFunctionKey(func);
                        FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(funcparts[0], funcparts[1]);
                        //reset recompilation flags according to recompileOnce because it is only safe if function is recompileOnce
                        //because then recompiled for every execution (otherwise potential issues if func also called outside parfor)
                        Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new LocalVariableMap(), 0, fpb.isRecompileOnce());
                    }
                }
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }
    try {
        //create opt tree (before optimization)
        try {
            tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
            //explain() renders the whole plan; only pay for it if debug output is enabled
            if (LOG.isDebugEnabled())
                LOG.debug("ParFOR Opt: Input plan (before optimization):\n" + tree.explain(false));
        } catch (Exception ex) {
            throw new DMLRuntimeException("Unable to create opt tree.", ex);
        }
        //create cost estimator
        CostEstimator est = createCostEstimator(cmtype, ec.getVariables());
        LOG.trace("ParFOR Opt: Created cost estimator (" + cmtype + ")");
        //core optimize
        opt.optimize(sb, pb, tree, est, ec);
        if (LOG.isDebugEnabled())
            LOG.debug("ParFOR Opt: Optimized plan (after optimization): \n" + tree.explain(false));
        //assert plan correctness
        if (CHECK_PLAN_CORRECTNESS && LOG.isDebugEnabled()) {
            try {
                OptTreePlanChecker.checkProgramCorrectness(pb, sb, new HashSet<String>());
                LOG.debug("ParFOR Opt: Checked plan and program correctness.");
            } catch (Exception ex) {
                throw new DMLRuntimeException("Failed to check program correctness.", ex);
            }
        }
        long ltime = (long) time.stop();
        LOG.trace("ParFOR Opt: Optimized plan in " + ltime + "ms.");
        if (DMLScript.STATISTICS)
            Statistics.incrementParForOptimTime(ltime);
    } finally {
        //cleanup phase: always release the converter's plan mapping, also if optimization failed
        OptTreeConverter.clear();
    }
    //monitor stats
    if (monitor) {
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_OPTIMIZER, otype.ordinal());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMTPLANS, opt.getNumTotalPlans());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMEPLANS, opt.getNumEvaluatedPlans());
    }
}
Also used: FunctionProgramBlock(org.apache.sysml.runtime.controlprogram.FunctionProgramBlock) ProgramRewriter(org.apache.sysml.hops.rewrite.ProgramRewriter) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) LocalVariableMap(org.apache.sysml.runtime.controlprogram.LocalVariableMap) CostModelType(org.apache.sysml.runtime.controlprogram.parfor.opt.Optimizer.CostModelType) InterProceduralAnalysis(org.apache.sysml.hops.ipa.InterProceduralAnalysis) Timing(org.apache.sysml.runtime.controlprogram.parfor.stat.Timing) ForStatement(org.apache.sysml.parser.ForStatement) ProgramRewriteStatus(org.apache.sysml.hops.rewrite.ProgramRewriteStatus)

Example 3 with ProgramRewriteStatus

Use of org.apache.sysml.hops.rewrite.ProgramRewriteStatus in project incubator-systemml by Apache.

In class SpoofCompiler, method optimize:

/**
 * Main interface of sum-product optimizer, statement block dag.
 * <p>
 * Constructs and cleans up code generation plans (cplans) for the given DAG. For each
 * plan it either reuses the class from the plan cache or generates and compiles Java
 * source code (populating the cache unless the cache policy is NONE). If any plans were
 * constructed, it builds the modified hop DAG and runs common subexpression elimination
 * on it.
 *
 * @param roots dag root nodes
 * @param recompile true if invoked during dynamic recompilation
 * @return dag root nodes of modified dag; the given roots unchanged if they are
 *         null/empty or no cplans were constructed
 * @throws DMLRuntimeException if optimization failed
 */
public static ArrayList<Hop> optimize(ArrayList<Hop> roots, boolean recompile) throws DMLRuntimeException {
    if (roots == null || roots.isEmpty())
        return roots;
    long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
    ArrayList<Hop> ret = roots;
    try {
        //context-sensitive literal replacement (only integers during recompile)
        boolean compileLiterals = (PLAN_CACHE_POLICY == PlanCachePolicy.CONSTANT) || !recompile;
        //construct codegen plans
        HashMap<Long, Pair<Hop[], CNodeTpl>> cplans = constructCPlans(roots, compileLiterals);
        //cleanup codegen plans (remove unnecessary inputs, fix hop-cnodedata mapping,
        //remove empty templates with single cnodedata input, remove spurious lookups)
        cplans = cleanupCPlans(cplans);
        //explain before modification
        if (LOG.isTraceEnabled() && !cplans.isEmpty()) {
            //existing cplans
            LOG.trace("Codegen EXPLAIN (before optimize): \n" + Explain.explainHops(roots));
        }
        //source code generation for all cplans
        HashMap<Long, Pair<Hop[], Class<?>>> clas = new HashMap<Long, Pair<Hop[], Class<?>>>();
        for (Entry<Long, Pair<Hop[], CNodeTpl>> cplan : cplans.entrySet()) {
            Pair<Hop[], CNodeTpl> tmp = cplan.getValue();
            CNodeTpl tpl = tmp.getValue();
            Class<?> cla = planCache.getPlan(tpl);
            if (cla == null) {
                //generate java source code
                String src = tpl.codegen(false);
                //explain debug output cplans or generated source code
                if (LOG.isTraceEnabled() || DMLScript.EXPLAIN.isHopsType(recompile)) {
                    LOG.info("Codegen EXPLAIN (generated cplan for HopID: " + cplan.getKey() + "):");
                    LOG.info(tpl.getClassname() + Explain.explainCPlan(tpl));
                }
                if (LOG.isTraceEnabled() || DMLScript.EXPLAIN.isRuntimeType(recompile)) {
                    LOG.info("Codegen EXPLAIN (generated code for HopID: " + cplan.getKey() + "):");
                    LOG.info(src);
                }
                //compile generated java source code
                cla = CodegenUtils.compileClass("codegen." + tpl.getClassname(), src);
                //maintain plan cache
                if (PLAN_CACHE_POLICY != PlanCachePolicy.NONE)
                    planCache.putPlan(tpl, cla);
            } else if (DMLScript.STATISTICS) {
                Statistics.incrementCodegenPlanCacheHits();
            }
            //make class available and maintain hits
            if (cla != null)
                clas.put(cplan.getKey(), new Pair<Hop[], Class<?>>(tmp.getKey(), cla));
            if (DMLScript.STATISTICS)
                Statistics.incrementCodegenPlanCacheTotal();
        }
        //create modified hop dag (operator replacement and CSE)
        if (!cplans.isEmpty()) {
            //generate final hop dag
            ret = constructModifiedHopDag(roots, cplans, clas);
            //run common subexpression elimination and other rewrites
            ret = rewriteCSE.rewriteHopDAGs(ret, new ProgramRewriteStatus());
            //explain after modification: show the returned (modified) dag, which the
            //rewrites may have replaced, rather than the input roots
            if (LOG.isTraceEnabled()) {
                LOG.trace("Codegen EXPLAIN (after optimize): \n" + Explain.explainHops(ret));
            }
        }
    } catch (Exception ex) {
        LOG.error("Codegen failed to optimize the following HOP DAG: \n" + Explain.explainHops(roots));
        throw new DMLRuntimeException(ex);
    }
    if (DMLScript.STATISTICS) {
        Statistics.incrementCodegenDAGCompile();
        Statistics.incrementCodegenCompileTime(System.nanoTime() - t0);
    }
    Hop.resetVisitStatus(roots);
    return ret;
}
Also used: CNodeTpl(org.apache.sysml.hops.codegen.cplan.CNodeTpl) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) Hop(org.apache.sysml.hops.Hop) HopsException(org.apache.sysml.hops.HopsException) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) LopsException(org.apache.sysml.lops.LopsException) DMLException(org.apache.sysml.api.DMLException) LanguageException(org.apache.sysml.parser.LanguageException) IOException(java.io.IOException) ProgramRewriteStatus(org.apache.sysml.hops.rewrite.ProgramRewriteStatus) Pair(org.apache.sysml.runtime.matrix.data.Pair)

Aggregations

ProgramRewriteStatus (org.apache.sysml.hops.rewrite.ProgramRewriteStatus): 3
DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 3
IOException (java.io.IOException): 2
HopsException (org.apache.sysml.hops.HopsException): 2
ProgramRewriter (org.apache.sysml.hops.rewrite.ProgramRewriter): 2
LopsException (org.apache.sysml.lops.LopsException): 2
LanguageException (org.apache.sysml.parser.LanguageException): 2
HashMap (java.util.HashMap): 1
LinkedHashMap (java.util.LinkedHashMap): 1
DMLException (org.apache.sysml.api.DMLException): 1
Hop (org.apache.sysml.hops.Hop): 1
CNodeTpl (org.apache.sysml.hops.codegen.cplan.CNodeTpl): 1
InterProceduralAnalysis (org.apache.sysml.hops.ipa.InterProceduralAnalysis): 1
RewriteInjectSparkLoopCheckpointing (org.apache.sysml.hops.rewrite.RewriteInjectSparkLoopCheckpointing): 1
ForStatement (org.apache.sysml.parser.ForStatement): 1
ParForStatement (org.apache.sysml.parser.ParForStatement): 1
ParForStatementBlock (org.apache.sysml.parser.ParForStatementBlock): 1
FunctionProgramBlock (org.apache.sysml.runtime.controlprogram.FunctionProgramBlock): 1
LocalVariableMap (org.apache.sysml.runtime.controlprogram.LocalVariableMap): 1
ParForProgramBlock (org.apache.sysml.runtime.controlprogram.ParForProgramBlock): 1