Search in sources :

Example 1 with CostModelType

use of org.apache.sysml.runtime.controlprogram.parfor.opt.Optimizer.CostModelType in project systemml by apache.

the class OptimizationWrapper method optimize.

@SuppressWarnings("unused")
private static void optimize(POptMode otype, int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean monitor) {
    Timing time = new Timing(true);
    // maintain statistics
    if (DMLScript.STATISTICS)
        Statistics.incrementParForOptimCount();
    // create specified optimizer
    Optimizer opt = createOptimizer(otype);
    CostModelType cmtype = opt.getCostModelType();
    LOG.trace("ParFOR Opt: Created optimizer (" + otype + "," + opt.getPlanInputType() + "," + opt.getCostModelType());
    OptTree tree = null;
    // recompile parfor body
    if (ConfigurationManager.isDynamicRecompilation()) {
        ForStatement fs = (ForStatement) sb.getStatement(0);
        // debug output before recompilation
        if (LOG.isDebugEnabled()) {
            try {
                tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
                LOG.debug("ParFOR Opt: Input plan (before recompilation):\n" + tree.explain(false));
                OptTreeConverter.clear();
            } catch (Exception ex) {
                throw new DMLRuntimeException("Unable to create opt tree.", ex);
            }
        }
        // separate propagation required because recompile in-place without literal replacement)
        try {
            LocalVariableMap constVars = ProgramRecompiler.getReusableScalarVariables(sb.getDMLProg(), sb, ec.getVariables());
            ProgramRecompiler.replaceConstantScalarVariables(sb, constVars);
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        // program rewrites (e.g., constant folding, branch removal) according to replaced literals
        try {
            ProgramRewriter rewriter = createProgramRewriterWithRuleSets();
            ProgramRewriteStatus state = new ProgramRewriteStatus();
            rewriter.rRewriteStatementBlockHopDAGs(sb, state);
            fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state, true));
            if (state.getRemovedBranches()) {
                LOG.debug("ParFOR Opt: Removed branches during program rewrites, rebuilding runtime program");
                pb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(pb.getProgram(), fs.getBody()));
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        // recompilation of parfor body and called functions (if safe)
        try {
            // core parfor body recompilation (based on symbol table entries)
            // * clone of variables in order to allow for statistics propagation across DAGs
            // (tid=0, because deep copies created after opt)
            LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone();
            ResetType reset = ConfigurationManager.isCodegenEnabled() ? ResetType.RESET_KNOWN_DIMS : ResetType.RESET;
            Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, reset);
            // inter-procedural optimization (based on previous recompilation)
            if (pb.hasFunctions()) {
                InterProceduralAnalysis ipa = new InterProceduralAnalysis(sb);
                Set<String> fcand = ipa.analyzeSubProgram();
                if (!fcand.isEmpty()) {
                    // regenerate runtime program of modified functions
                    for (String func : fcand) {
                        String[] funcparts = DMLProgram.splitFunctionKey(func);
                        FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(funcparts[0], funcparts[1]);
                        // reset recompilation flags according to recompileOnce because it is only safe if function is recompileOnce
                        // because then recompiled for every execution (otherwise potential issues if func also called outside parfor)
                        ResetType reset2 = fpb.isRecompileOnce() ? reset : ResetType.NO_RESET;
                        Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new LocalVariableMap(), 0, reset2);
                    }
                }
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }
    // create opt tree (before optimization)
    try {
        tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
        LOG.debug("ParFOR Opt: Input plan (before optimization):\n" + tree.explain(false));
    } catch (Exception ex) {
        throw new DMLRuntimeException("Unable to create opt tree.", ex);
    }
    // create cost estimator
    CostEstimator est = createCostEstimator(cmtype, ec.getVariables());
    LOG.trace("ParFOR Opt: Created cost estimator (" + cmtype + ")");
    // core optimize
    opt.optimize(sb, pb, tree, est, ec);
    LOG.debug("ParFOR Opt: Optimized plan (after optimization): \n" + tree.explain(false));
    // assert plan correctness
    if (CHECK_PLAN_CORRECTNESS && LOG.isDebugEnabled()) {
        try {
            OptTreePlanChecker.checkProgramCorrectness(pb, sb, new HashSet<String>());
            LOG.debug("ParFOR Opt: Checked plan and program correctness.");
        } catch (Exception ex) {
            throw new DMLRuntimeException("Failed to check program correctness.", ex);
        }
    }
    long ltime = (long) time.stop();
    LOG.trace("ParFOR Opt: Optimized plan in " + ltime + "ms.");
    if (DMLScript.STATISTICS)
        Statistics.incrementParForOptimTime(ltime);
    // cleanup phase
    OptTreeConverter.clear();
    // monitor stats
    if (monitor) {
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_OPTIMIZER, otype.ordinal());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMTPLANS, opt.getNumTotalPlans());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMEPLANS, opt.getNumEvaluatedPlans());
    }
}
Also used : FunctionProgramBlock(org.apache.sysml.runtime.controlprogram.FunctionProgramBlock) ProgramRewriter(org.apache.sysml.hops.rewrite.ProgramRewriter) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) LocalVariableMap(org.apache.sysml.runtime.controlprogram.LocalVariableMap) CostModelType(org.apache.sysml.runtime.controlprogram.parfor.opt.Optimizer.CostModelType) InterProceduralAnalysis(org.apache.sysml.hops.ipa.InterProceduralAnalysis) ResetType(org.apache.sysml.hops.recompile.Recompiler.ResetType) Timing(org.apache.sysml.runtime.controlprogram.parfor.stat.Timing) ForStatement(org.apache.sysml.parser.ForStatement) ProgramRewriteStatus(org.apache.sysml.hops.rewrite.ProgramRewriteStatus)

Example 2 with CostModelType

use of org.apache.sysml.runtime.controlprogram.parfor.opt.Optimizer.CostModelType in project incubator-systemml by apache.

the class OptimizationWrapper method optimize.

@SuppressWarnings("unused")
private static void optimize(POptMode otype, int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean monitor) {
    Timing time = new Timing(true);
    // maintain statistics
    if (DMLScript.STATISTICS)
        Statistics.incrementParForOptimCount();
    // create specified optimizer
    Optimizer opt = createOptimizer(otype);
    CostModelType cmtype = opt.getCostModelType();
    LOG.trace("ParFOR Opt: Created optimizer (" + otype + "," + opt.getPlanInputType() + "," + opt.getCostModelType());
    OptTree tree = null;
    // recompile parfor body
    if (ConfigurationManager.isDynamicRecompilation()) {
        ForStatement fs = (ForStatement) sb.getStatement(0);
        // debug output before recompilation
        if (LOG.isDebugEnabled()) {
            try {
                tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
                LOG.debug("ParFOR Opt: Input plan (before recompilation):\n" + tree.explain(false));
                OptTreeConverter.clear();
            } catch (Exception ex) {
                throw new DMLRuntimeException("Unable to create opt tree.", ex);
            }
        }
        // separate propagation required because recompile in-place without literal replacement)
        try {
            LocalVariableMap constVars = ProgramRecompiler.getReusableScalarVariables(sb.getDMLProg(), sb, ec.getVariables());
            ProgramRecompiler.replaceConstantScalarVariables(sb, constVars);
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        // program rewrites (e.g., constant folding, branch removal) according to replaced literals
        try {
            ProgramRewriter rewriter = createProgramRewriterWithRuleSets();
            ProgramRewriteStatus state = new ProgramRewriteStatus();
            rewriter.rRewriteStatementBlockHopDAGs(sb, state);
            fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state, true));
            if (state.getRemovedBranches()) {
                LOG.debug("ParFOR Opt: Removed branches during program rewrites, rebuilding runtime program");
                pb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(pb.getProgram(), fs.getBody()));
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        // recompilation of parfor body and called functions (if safe)
        try {
            // core parfor body recompilation (based on symbol table entries)
            // * clone of variables in order to allow for statistics propagation across DAGs
            // (tid=0, because deep copies created after opt)
            LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone();
            ResetType reset = ConfigurationManager.isCodegenEnabled() ? ResetType.RESET_KNOWN_DIMS : ResetType.RESET;
            Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, reset);
            // inter-procedural optimization (based on previous recompilation)
            if (pb.hasFunctions()) {
                InterProceduralAnalysis ipa = new InterProceduralAnalysis(sb);
                Set<String> fcand = ipa.analyzeSubProgram();
                if (!fcand.isEmpty()) {
                    // regenerate runtime program of modified functions
                    for (String func : fcand) {
                        String[] funcparts = DMLProgram.splitFunctionKey(func);
                        FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(funcparts[0], funcparts[1]);
                        // reset recompilation flags according to recompileOnce because it is only safe if function is recompileOnce
                        // because then recompiled for every execution (otherwise potential issues if func also called outside parfor)
                        ResetType reset2 = fpb.isRecompileOnce() ? reset : ResetType.NO_RESET;
                        Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new LocalVariableMap(), 0, reset2);
                    }
                }
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }
    // create opt tree (before optimization)
    try {
        tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
        LOG.debug("ParFOR Opt: Input plan (before optimization):\n" + tree.explain(false));
    } catch (Exception ex) {
        throw new DMLRuntimeException("Unable to create opt tree.", ex);
    }
    // create cost estimator
    CostEstimator est = createCostEstimator(cmtype, ec.getVariables());
    LOG.trace("ParFOR Opt: Created cost estimator (" + cmtype + ")");
    // core optimize
    opt.optimize(sb, pb, tree, est, ec);
    LOG.debug("ParFOR Opt: Optimized plan (after optimization): \n" + tree.explain(false));
    // assert plan correctness
    if (CHECK_PLAN_CORRECTNESS && LOG.isDebugEnabled()) {
        try {
            OptTreePlanChecker.checkProgramCorrectness(pb, sb, new HashSet<String>());
            LOG.debug("ParFOR Opt: Checked plan and program correctness.");
        } catch (Exception ex) {
            throw new DMLRuntimeException("Failed to check program correctness.", ex);
        }
    }
    long ltime = (long) time.stop();
    LOG.trace("ParFOR Opt: Optimized plan in " + ltime + "ms.");
    if (DMLScript.STATISTICS)
        Statistics.incrementParForOptimTime(ltime);
    // cleanup phase
    OptTreeConverter.clear();
    // monitor stats
    if (monitor) {
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_OPTIMIZER, otype.ordinal());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMTPLANS, opt.getNumTotalPlans());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMEPLANS, opt.getNumEvaluatedPlans());
    }
}
Also used : FunctionProgramBlock(org.apache.sysml.runtime.controlprogram.FunctionProgramBlock) ProgramRewriter(org.apache.sysml.hops.rewrite.ProgramRewriter) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) LocalVariableMap(org.apache.sysml.runtime.controlprogram.LocalVariableMap) CostModelType(org.apache.sysml.runtime.controlprogram.parfor.opt.Optimizer.CostModelType) InterProceduralAnalysis(org.apache.sysml.hops.ipa.InterProceduralAnalysis) ResetType(org.apache.sysml.hops.recompile.Recompiler.ResetType) Timing(org.apache.sysml.runtime.controlprogram.parfor.stat.Timing) ForStatement(org.apache.sysml.parser.ForStatement) ProgramRewriteStatus(org.apache.sysml.hops.rewrite.ProgramRewriteStatus)

Aggregations

InterProceduralAnalysis (org.apache.sysml.hops.ipa.InterProceduralAnalysis)2 ResetType (org.apache.sysml.hops.recompile.Recompiler.ResetType)2 ProgramRewriteStatus (org.apache.sysml.hops.rewrite.ProgramRewriteStatus)2 ProgramRewriter (org.apache.sysml.hops.rewrite.ProgramRewriter)2 ForStatement (org.apache.sysml.parser.ForStatement)2 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)2 FunctionProgramBlock (org.apache.sysml.runtime.controlprogram.FunctionProgramBlock)2 LocalVariableMap (org.apache.sysml.runtime.controlprogram.LocalVariableMap)2 CostModelType (org.apache.sysml.runtime.controlprogram.parfor.opt.Optimizer.CostModelType)2 Timing (org.apache.sysml.runtime.controlprogram.parfor.stat.Timing)2