Search in sources :

Example 1 with ProgramRewriter

use of org.apache.sysml.hops.rewrite.ProgramRewriter in project incubator-systemml by apache.

In the class Connection, the method prepareScript:

/**
 * Precompiles a DML or PyDML script into a {@code PreparedScript}, binding the
 * given input parameter values and registering the input and output variables.
 * <p>
 * The compilation chain is: parse, live-variable analysis and validation,
 * hop construction and rewrite, removal of persistent reads/writes for the
 * registered variables, lop construction, runtime program generation, and a
 * final cleanup of the runtime program.
 *
 * @param script string representing the DML or PyDML script
 * @param args map of input parameters ($) and their values
 * @param inputs string array of input variables to register
 * @param outputs string array of output variables to register
 * @param parsePyDML {@code true} if PyDML, {@code false} if DML
 * @return PreparedScript object representing the precompiled script
 * @throws DMLException if compilation fails; a ParseException is rethrown
 *         as-is, any other exception is wrapped in a DMLException
 */
public PreparedScript prepareScript(String script, Map<String, String> args, String[] inputs, String[] outputs, boolean parsePyDML) throws DMLException {
    // resolve the script language once: it is published in the static
    // DMLScript.SCRIPT_TYPE field and also selects the parser below
    final ScriptType scriptType = parsePyDML ? ScriptType.PYDML : ScriptType.DML;
    DMLScript.SCRIPT_TYPE = scriptType;

    final Program runtimeProgram;
    try {
        // step 1: parse the script text with the bound $-arguments
        ParserWrapper scriptParser = ParserFactory.createParser(scriptType);
        DMLProgram dmlProgram = scriptParser.parse(null, script, args);

        // step 2: language-level validation
        DMLTranslator translator = new DMLTranslator(dmlProgram);
        translator.liveVariableAnalysis(dmlProgram);
        translator.validateParseTree(dmlProgram);

        // step 3: build and rewrite the hop DAGs
        translator.constructHops(dmlProgram);
        translator.rewriteHopsDAG(dmlProgram);

        // step 4: rewrite persistent reads/writes of the registered inputs/outputs
        RewriteRemovePersistentReadWrite persistentRule = new RewriteRemovePersistentReadWrite(inputs, outputs);
        ProgramRewriter hopRewriter = new ProgramRewriter(persistentRule);
        hopRewriter.rewriteProgramHopDAGs(dmlProgram);

        // step 5: build lops and generate the runtime program
        translator.constructLops(dmlProgram);
        runtimeProgram = dmlProgram.getRuntimeProgram(_dmlconf);

        // step 6: final cleanup of the runtime program
        JMLCUtils.cleanupRuntimeProgram(runtimeProgram, outputs);
    } catch (ParseException pe) {
        // deliberately not chained, so parse errors are reported cleanly
        throw pe;
    } catch (Exception ex) {
        throw new DMLException(ex);
    }

    // hand back the newly created precompiled script
    return new PreparedScript(runtimeProgram, inputs, outputs);
}
Also used : ProgramRewriter(org.apache.sysml.hops.rewrite.ProgramRewriter) DMLProgram(org.apache.sysml.parser.DMLProgram) Program(org.apache.sysml.runtime.controlprogram.Program) DMLException(org.apache.sysml.api.DMLException) DMLProgram(org.apache.sysml.parser.DMLProgram) ParserWrapper(org.apache.sysml.parser.ParserWrapper) ParseException(org.apache.sysml.parser.ParseException) DMLTranslator(org.apache.sysml.parser.DMLTranslator) RewriteRemovePersistentReadWrite(org.apache.sysml.hops.rewrite.RewriteRemovePersistentReadWrite) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) DMLException(org.apache.sysml.api.DMLException) IOException(java.io.IOException) ParseException(org.apache.sysml.parser.ParseException)

Example 2 with ProgramRewriter

use of org.apache.sysml.hops.rewrite.ProgramRewriter in project incubator-systemml by apache.

In the class ScriptExecutor, the method rewritePersistentReadsAndWrites:

/**
 * Replace persistent reads and writes with transient reads and writes in
 * the symbol table.
 * <p>
 * Does nothing when the script has no symbol table. Missing input or output
 * variable collections are treated as empty.
 *
 * @throws MLContextException if the hop DAG rewrite fails with a
 *         LanguageException or a HopsException
 */
protected void rewritePersistentReadsAndWrites() {
    // without a symbol table there is nothing to rewrite
    if (script.getSymbolTable() == null) {
        return;
    }

    // registered variable names as arrays; null collections become empty arrays
    String[] inputs;
    if (script.getInputVariables() == null) {
        inputs = new String[0];
    } else {
        inputs = script.getInputVariables().toArray(new String[0]);
    }
    String[] outputs;
    if (script.getOutputVariables() == null) {
        outputs = new String[0];
    } else {
        outputs = script.getOutputVariables().toArray(new String[0]);
    }

    ProgramRewriter programRewriter = new ProgramRewriter(
        new RewriteRemovePersistentReadWrite(inputs, outputs, script.getSymbolTable()));

    final String failureMessage = "Exception occurred while rewriting persistent reads and writes";
    try {
        programRewriter.rewriteProgramHopDAGs(dmlProgram);
    } catch (LanguageException e) {
        throw new MLContextException(failureMessage, e);
    } catch (HopsException e) {
        throw new MLContextException(failureMessage, e);
    }
}
Also used : LanguageException(org.apache.sysml.parser.LanguageException) ProgramRewriter(org.apache.sysml.hops.rewrite.ProgramRewriter) LocalVariableMap(org.apache.sysml.runtime.controlprogram.LocalVariableMap) HopsException(org.apache.sysml.hops.HopsException) RewriteRemovePersistentReadWrite(org.apache.sysml.hops.rewrite.RewriteRemovePersistentReadWrite)

Example 3 with ProgramRewriter

use of org.apache.sysml.hops.rewrite.ProgramRewriter in project incubator-systemml by apache.

In the class OptimizerRuleBased, the method rewriteInjectSparkLoopCheckpointing:

///////
//REWRITE inject spark loop checkpointing
///
/**
 * Applies the hop rewrite that injects Spark checkpoints into the parfor
 * mapped to the given optimizer node, and regenerates the parfor's child
 * program blocks if any checkpoint was injected.
 *
 * @param n optimizer tree node whose ID is used to look up the mapped
 *          parfor statement block and program block
 * @throws DMLRuntimeException wrapping any exception raised during the
 *         rewrite or the runtime program regeneration
 */
protected void rewriteInjectSparkLoopCheckpointing(OptNode n) throws DMLRuntimeException {
    // resolve the statement block / program block pair of the root parfor
    Object[] mapped = OptTreeConverter.getAbstractPlanMapping().getMappedProg(n.getID());
    ParForStatementBlock parforSB = (ParForStatementBlock) mapped[0];
    ParForStatement parforStmt = (ParForStatement) parforSB.getStatement(0);
    ParForProgramBlock parforPB = (ParForProgramBlock) mapped[1];

    final boolean applied;
    try {
        // hop rewrite for spark checkpoint injection (but without context awareness)
        ProgramRewriter rewriter = new ProgramRewriter(new RewriteInjectSparkLoopCheckpointing(false));
        ProgramRewriteStatus status = new ProgramRewriteStatus();
        rewriter.rewriteStatementBlockHopDAGs(parforSB, status);
        parforStmt.setBody(rewriter.rewriteStatementBlocks(parforStmt.getBody(), status));

        // the runtime program only needs regenerating if checkpoints were added
        applied = status.getInjectedCheckpoints();
        if (applied) {
            parforPB.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(parforPB.getProgram(), parforStmt.getBody()));
        }
    } catch (Exception ex) {
        throw new DMLRuntimeException(ex);
    }

    LOG.debug(getOptMode() + " OPT: rewrite 'inject spark loop checkpointing' - result=" + applied);
}
Also used : ProgramRewriter(org.apache.sysml.hops.rewrite.ProgramRewriter) ParForStatementBlock(org.apache.sysml.parser.ParForStatementBlock) MatrixObject(org.apache.sysml.runtime.controlprogram.caching.MatrixObject) RDDObject(org.apache.sysml.runtime.instructions.spark.data.RDDObject) ParForStatement(org.apache.sysml.parser.ParForStatement) ProgramRewriteStatus(org.apache.sysml.hops.rewrite.ProgramRewriteStatus) RewriteInjectSparkLoopCheckpointing(org.apache.sysml.hops.rewrite.RewriteInjectSparkLoopCheckpointing) HopsException(org.apache.sysml.hops.HopsException) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) LopsException(org.apache.sysml.lops.LopsException) LanguageException(org.apache.sysml.parser.LanguageException) IOException(java.io.IOException) ParForProgramBlock(org.apache.sysml.runtime.controlprogram.ParForProgramBlock) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Example 4 with ProgramRewriter

use of org.apache.sysml.hops.rewrite.ProgramRewriter in project incubator-systemml by apache.

In the class OptimizationWrapper, the method createProgramRewriterWithRuleSets:

/**
 * Builds a ProgramRewriter configured with exactly two rules: the hop rule
 * RewriteConstantFolding and the statement-block rule
 * RewriteRemoveUnnecessaryBranches.
 *
 * @return a new ProgramRewriter holding those two rule sets
 */
private static ProgramRewriter createProgramRewriterWithRuleSets() {
    // hop-level rules
    ArrayList<HopRewriteRule> hopRules = new ArrayList<HopRewriteRule>();
    hopRules.add(new RewriteConstantFolding());

    // statement-block-level rules
    ArrayList<StatementBlockRewriteRule> blockRules = new ArrayList<StatementBlockRewriteRule>();
    blockRules.add(new RewriteRemoveUnnecessaryBranches());

    return new ProgramRewriter(hopRules, blockRules);
}
Also used : ProgramRewriter(org.apache.sysml.hops.rewrite.ProgramRewriter) StatementBlockRewriteRule(org.apache.sysml.hops.rewrite.StatementBlockRewriteRule) ArrayList(java.util.ArrayList) RewriteRemoveUnnecessaryBranches(org.apache.sysml.hops.rewrite.RewriteRemoveUnnecessaryBranches) RewriteConstantFolding(org.apache.sysml.hops.rewrite.RewriteConstantFolding) HopRewriteRule(org.apache.sysml.hops.rewrite.HopRewriteRule)

Example 5 with ProgramRewriter

use of org.apache.sysml.hops.rewrite.ProgramRewriter in project incubator-systemml by apache.

In the class OptimizationWrapper, the method optimize:

/**
 * Optimizes a single parfor: optionally recompiles its body, builds the
 * optimizer tree, runs the selected optimizer and records statistics.
 * <p>
 * Steps, in order:
 * <ol>
 *   <li>create the optimizer for {@code otype};</li>
 *   <li>if dynamic recompilation is enabled: replace constant scalar
 *       variables with literals, apply program rewrites (constant folding,
 *       branch removal), recompile the parfor body and, where candidates
 *       are found, the functions it calls;</li>
 *   <li>create the opt tree and the cost estimator;</li>
 *   <li>run the optimizer;</li>
 *   <li>optionally check plan correctness, record timing and monitor
 *       statistics, and clear the OptTreeConverter state.</li>
 * </ol>
 *
 * @param otype optimizer mode used to create the optimizer
 * @param ck passed through to OptTreeConverter.createOptTree
 *           (NOTE(review): presumably the degree of parallelism - confirm)
 * @param cm passed through to OptTreeConverter.createOptTree
 *           (NOTE(review): presumably the memory budget - confirm)
 * @param sb parfor statement block to optimize
 * @param pb parfor program block belonging to {@code sb}
 * @param ec execution context supplying the current variables
 * @param monitor if {@code true}, optimizer statistics are pushed to the
 *                StatisticMonitor
 * @throws DMLRuntimeException if recompilation, opt tree creation,
 *         optimization or the plan check fails
 */
@SuppressWarnings("unused")
private static void optimize(POptMode otype, int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean monitor) throws DMLRuntimeException {
    Timing time = new Timing(true);
    //maintain statistics
    if (DMLScript.STATISTICS)
        Statistics.incrementParForOptimCount();
    //create specified optimizer
    Optimizer opt = createOptimizer(otype);
    CostModelType cmtype = opt.getCostModelType();
    //reuse cmtype and close the parenthesis opened in the message
    LOG.trace("ParFOR Opt: Created optimizer (" + otype + "," + opt.getPlanInputType() + "," + cmtype + ")");
    OptTree tree = null;
    //recompile parfor body 
    if (ConfigurationManager.isDynamicRecompilation()) {
        ForStatement fs = (ForStatement) sb.getStatement(0);
        //debug output before recompilation
        if (LOG.isDebugEnabled()) {
            try {
                tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
                LOG.debug("ParFOR Opt: Input plan (before recompilation):\n" + tree.explain(false));
                OptTreeConverter.clear();
            } catch (Exception ex) {
                throw new DMLRuntimeException("Unable to create opt tree.", ex);
            }
        }
        //constant propagation for parfor body
        //(separate propagation required because recompile in-place without literal replacement)
        try {
            LocalVariableMap constVars = ProgramRecompiler.getReusableScalarVariables(sb.getDMLProg(), sb, ec.getVariables());
            ProgramRecompiler.replaceConstantScalarVariables(sb, constVars);
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        //program rewrites (e.g., constant folding, branch removal) according to replaced literals
        try {
            ProgramRewriter rewriter = createProgramRewriterWithRuleSets();
            ProgramRewriteStatus state = new ProgramRewriteStatus();
            rewriter.rewriteStatementBlockHopDAGs(sb, state);
            fs.setBody(rewriter.rewriteStatementBlocks(fs.getBody(), state));
            //the runtime program only needs rebuilding if branches were removed
            if (state.getRemovedBranches()) {
                LOG.debug("ParFOR Opt: Removed branches during program rewrites, rebuilding runtime program");
                pb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(pb.getProgram(), fs.getBody()));
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        //recompilation of parfor body and called functions (if safe)
        try {
            //core parfor body recompilation (based on symbol table entries)
            //* clone of variables in order to allow for statistics propagation across DAGs
            //(tid=0, because deep copies created after opt)
            LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone();
            Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, true);
            //inter-procedural optimization (based on previous recompilation)
            if (pb.hasFunctions()) {
                InterProceduralAnalysis ipa = new InterProceduralAnalysis();
                Set<String> fcand = ipa.analyzeSubProgram(sb);
                if (!fcand.isEmpty()) {
                    //regenerate runtime program of modified functions
                    for (String func : fcand) {
                        String[] funcparts = DMLProgram.splitFunctionKey(func);
                        FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(funcparts[0], funcparts[1]);
                        //reset recompilation flags according to recompileOnce because it is only safe if function is recompileOnce 
                        //because then recompiled for every execution (otherwise potential issues if func also called outside parfor)
                        Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new LocalVariableMap(), 0, fpb.isRecompileOnce());
                    }
                }
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }
    //create opt tree (before optimization)
    try {
        tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
        //guarded so the plan explanation is only built when it will be logged
        if (LOG.isDebugEnabled())
            LOG.debug("ParFOR Opt: Input plan (before optimization):\n" + tree.explain(false));
    } catch (Exception ex) {
        throw new DMLRuntimeException("Unable to create opt tree.", ex);
    }
    //create cost estimator
    CostEstimator est = createCostEstimator(cmtype, ec.getVariables());
    LOG.trace("ParFOR Opt: Created cost estimator (" + cmtype + ")");
    //core optimize
    opt.optimize(sb, pb, tree, est, ec);
    //guarded so the plan explanation is only built when it will be logged
    if (LOG.isDebugEnabled())
        LOG.debug("ParFOR Opt: Optimized plan (after optimization): \n" + tree.explain(false));
    //assert plan correctness
    if (CHECK_PLAN_CORRECTNESS && LOG.isDebugEnabled()) {
        try {
            OptTreePlanChecker.checkProgramCorrectness(pb, sb, new HashSet<String>());
            LOG.debug("ParFOR Opt: Checked plan and program correctness.");
        } catch (Exception ex) {
            throw new DMLRuntimeException("Failed to check program correctness.", ex);
        }
    }
    long ltime = (long) time.stop();
    LOG.trace("ParFOR Opt: Optimized plan in " + ltime + "ms.");
    if (DMLScript.STATISTICS)
        Statistics.incrementParForOptimTime(ltime);
    //cleanup phase
    OptTreeConverter.clear();
    //monitor stats
    if (monitor) {
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_OPTIMIZER, otype.ordinal());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMTPLANS, opt.getNumTotalPlans());
        StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMEPLANS, opt.getNumEvaluatedPlans());
    }
}
Also used : FunctionProgramBlock(org.apache.sysml.runtime.controlprogram.FunctionProgramBlock) ProgramRewriter(org.apache.sysml.hops.rewrite.ProgramRewriter) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) LocalVariableMap(org.apache.sysml.runtime.controlprogram.LocalVariableMap) CostModelType(org.apache.sysml.runtime.controlprogram.parfor.opt.Optimizer.CostModelType) InterProceduralAnalysis(org.apache.sysml.hops.ipa.InterProceduralAnalysis) Timing(org.apache.sysml.runtime.controlprogram.parfor.stat.Timing) ForStatement(org.apache.sysml.parser.ForStatement) ProgramRewriteStatus(org.apache.sysml.hops.rewrite.ProgramRewriteStatus)

Aggregations

ProgramRewriter (org.apache.sysml.hops.rewrite.ProgramRewriter)6 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)3 IOException (java.io.IOException)2 HopsException (org.apache.sysml.hops.HopsException)2 InterProceduralAnalysis (org.apache.sysml.hops.ipa.InterProceduralAnalysis)2 ProgramRewriteStatus (org.apache.sysml.hops.rewrite.ProgramRewriteStatus)2 RewriteRemovePersistentReadWrite (org.apache.sysml.hops.rewrite.RewriteRemovePersistentReadWrite)2 LanguageException (org.apache.sysml.parser.LanguageException)2 LocalVariableMap (org.apache.sysml.runtime.controlprogram.LocalVariableMap)2 ArrayList (java.util.ArrayList)1 DMLException (org.apache.sysml.api.DMLException)1 HopRewriteRule (org.apache.sysml.hops.rewrite.HopRewriteRule)1 RewriteConstantFolding (org.apache.sysml.hops.rewrite.RewriteConstantFolding)1 RewriteInjectSparkLoopCheckpointing (org.apache.sysml.hops.rewrite.RewriteInjectSparkLoopCheckpointing)1 RewriteRemoveUnnecessaryBranches (org.apache.sysml.hops.rewrite.RewriteRemoveUnnecessaryBranches)1 StatementBlockRewriteRule (org.apache.sysml.hops.rewrite.StatementBlockRewriteRule)1 LopsException (org.apache.sysml.lops.LopsException)1 DMLProgram (org.apache.sysml.parser.DMLProgram)1 DMLTranslator (org.apache.sysml.parser.DMLTranslator)1 ForStatement (org.apache.sysml.parser.ForStatement)1