Search in sources:

Example 26 with ForStatementBlock

Use of org.apache.sysml.parser.ForStatementBlock in project incubator-systemml by Apache.

In class SpoofCompiler, method generateCodeFromStatementBlock.

/**
 * Recursively walks a statement block and applies code generation to the hop DAGs it holds.
 * <p>
 * Control-flow blocks are descended into: the predicate DAGs of while, if and for
 * (incl. parfor) blocks are replaced by the result of {@code optimize(..., false)}, and all
 * nested blocks are processed recursively. For a last-level block, its hop DAG roots are
 * replaced via {@code generateCodeFromHopDAGs} and its recompilation flag is refreshed.
 *
 * @param current statement block to process
 * @throws HopsException if raised while processing the hop DAGs
 * @throws DMLRuntimeException if raised while processing the hop DAGs
 */
public static void generateCodeFromStatementBlock(StatementBlock current) throws HopsException, DMLRuntimeException {
    if (current instanceof FunctionStatementBlock) {
        FunctionStatementBlock funcBlock = (FunctionStatementBlock) current;
        FunctionStatement funcStmt = (FunctionStatement) funcBlock.getStatement(0);
        for (StatementBlock child : funcStmt.getBody()) {
            generateCodeFromStatementBlock(child);
        }
        return;
    }

    if (current instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) current;
        WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0);
        whileBlock.setPredicateHops(optimize(whileBlock.getPredicateHops(), false));
        for (StatementBlock child : whileStmt.getBody()) {
            generateCodeFromStatementBlock(child);
        }
        return;
    }

    if (current instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) current;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        ifBlock.setPredicateHops(optimize(ifBlock.getPredicateHops(), false));
        for (StatementBlock child : ifStmt.getIfBody()) {
            generateCodeFromStatementBlock(child);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            generateCodeFromStatementBlock(child);
        }
        return;
    }

    if (current instanceof ForStatementBlock) {
        // covers parfor as well
        ForStatementBlock forBlock = (ForStatementBlock) current;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        forBlock.setFromHops(optimize(forBlock.getFromHops(), false));
        forBlock.setToHops(optimize(forBlock.getToHops(), false));
        forBlock.setIncrementHops(optimize(forBlock.getIncrementHops(), false));
        for (StatementBlock child : forStmt.getBody()) {
            generateCodeFromStatementBlock(child);
        }
        return;
    }

    // generic last-level block: rewrite its hop DAGs, then refresh the recompilation flag
    current.set_hops(generateCodeFromHopDAGs(current.get_hops()));
    current.updateRecompilationFlag();
}
Also used : ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) FunctionStatement(org.apache.sysml.parser.FunctionStatement) IfStatement(org.apache.sysml.parser.IfStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) WhileStatement(org.apache.sysml.parser.WhileStatement) ForStatement(org.apache.sysml.parser.ForStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Example 27 with ForStatementBlock

Use of org.apache.sysml.parser.ForStatementBlock in project incubator-systemml by Apache.

In class GraphBuilder, method constructGDFGraph.

/**
 * Recursively builds the global data flow (GDF) graph for a program block and records the
 * resulting data flow roots in {@code roots}, keyed by hop/variable name.
 * <p>
 * Function program blocks are not supported and raise an exception. While and for
 * (incl. parfor) blocks become a {@code GDFLoopNode}: the loop body is built on a copy of the
 * loop input nodes, and cross-block nodes are created for the loop outputs. If blocks build
 * their if and else branches on separate copies of {@code roots` and merge the branch roots
 * afterwards. Last-level blocks translate every hop DAG root into a GDF node, wrapping
 * transient writes into {@code GDFCrossBlockNode}s.
 *
 * @param pb program block to translate
 * @param roots current data flow roots by name; updated in place
 * @throws DMLRuntimeException if {@code pb} is a function program block, or if raised by a helper
 * @throws HopsException if no GDF node can be constructed for a hop DAG root, or if raised by a helper
 */
@SuppressWarnings("unchecked")
private static void constructGDFGraph(ProgramBlock pb, HashMap<String, GDFNode> roots) throws DMLRuntimeException, HopsException {
    if (pb instanceof FunctionProgramBlock) {
        throw new DMLRuntimeException("FunctionProgramBlocks not implemented yet.");
    } else if (pb instanceof WhileProgramBlock) {
        WhileProgramBlock wpb = (WhileProgramBlock) pb;
        WhileStatementBlock wsb = (WhileStatementBlock) pb.getStatementBlock();
        //construct predicate node from the while predicate dag (with a fresh local memo)
        GDFNode pred = constructGDFGraph(wsb.getPredicateHops(), wpb, new HashMap<Long, GDFNode>(), roots);
        HashMap<String, GDFNode> inputs = constructLoopInputNodes(wpb, wsb, roots);
        //the loop body works on a shallow copy of the loop inputs
        HashMap<String, GDFNode> lroots = (HashMap<String, GDFNode>) inputs.clone();
        //process child blocks
        for (ProgramBlock pbc : wpb.getChildBlocks()) {
            constructGDFGraph(pbc, lroots);
        }
        HashMap<String, GDFNode> outputs = constructLoopOutputNodes(wsb, lroots);
        GDFLoopNode lnode = new GDFLoopNode(wpb, pred, inputs, outputs);
        //construct crossblock nodes
        constructLoopOutputCrossBlockNodes(wsb, lnode, outputs, roots, wpb);
    } else if (pb instanceof IfProgramBlock) {
        IfProgramBlock ipb = (IfProgramBlock) pb;
        IfStatementBlock isb = (IfStatementBlock) pb.getStatementBlock();
        //construct predicate (registered as a root under the predicate hop's name)
        if (isb.getPredicateHops() != null) {
            Hop pred = isb.getPredicateHops();
            roots.put(pred.getName(), constructGDFGraph(pred, ipb, new HashMap<Long, GDFNode>(), roots));
        }
        //construct if and else branch separately, each on its own copy of the roots
        HashMap<String, GDFNode> ifRoots = (HashMap<String, GDFNode>) roots.clone();
        HashMap<String, GDFNode> elseRoots = (HashMap<String, GDFNode>) roots.clone();
        for (ProgramBlock pbc : ipb.getChildBlocksIfBody()) {
            constructGDFGraph(pbc, ifRoots);
        }
        if (ipb.getChildBlocksElseBody() != null) {
            for (ProgramBlock pbc : ipb.getChildBlocksElseBody()) {
                constructGDFGraph(pbc, elseRoots);
            }
        }
        //merge data flow roots (if no else, elseRoots is an unmodified copy of the original roots)
        reconcileMergeIfProgramBlockOutputs(ifRoots, elseRoots, roots, ipb);
    } else if (pb instanceof ForProgramBlock) {
        //incl parfor
        ForProgramBlock fpb = (ForProgramBlock) pb;
        ForStatementBlock fsb = (ForStatementBlock) pb.getStatementBlock();
        //construct predicate node (conceptually sequence of from/to/incr)
        GDFNode pred = constructForPredicateNode(fpb, fsb, roots);
        HashMap<String, GDFNode> inputs = constructLoopInputNodes(fpb, fsb, roots);
        //the loop body works on a shallow copy of the loop inputs
        HashMap<String, GDFNode> lroots = (HashMap<String, GDFNode>) inputs.clone();
        //process child blocks
        for (ProgramBlock pbc : fpb.getChildBlocks()) {
            constructGDFGraph(pbc, lroots);
        }
        HashMap<String, GDFNode> outputs = constructLoopOutputNodes(fsb, lroots);
        GDFLoopNode lnode = new GDFLoopNode(fpb, pred, inputs, outputs);
        //construct crossblock nodes
        constructLoopOutputCrossBlockNodes(fsb, lnode, outputs, roots, fpb);
    } else {
        //last-level program block
        StatementBlock sb = pb.getStatementBlock();
        ArrayList<Hop> hops = sb.get_hops();
        if (hops != null) {
            //create new local memo structure for local dag
            HashMap<Long, GDFNode> lmemo = new HashMap<Long, GDFNode>();
            for (Hop hop : hops) {
                //recursively construct GDF graph for hop dag root
                GDFNode root = constructGDFGraph(hop, pb, lmemo, roots);
                if (root == null)
                    throw new HopsException("GDFGraphBuilder: failed to construct dag root for: " + Explain.explain(hop));
                //create cross block nodes for all transient writes
                if (hop instanceof DataOp && ((DataOp) hop).getDataOpType() == DataOpTypes.TRANSIENTWRITE)
                    root = new GDFCrossBlockNode(hop, pb, root, hop.getName());
                //add GDF root node to global roots
                roots.put(hop.getName(), root);
            }
        }
    }
}
Also used : FunctionProgramBlock(org.apache.sysml.runtime.controlprogram.FunctionProgramBlock) IfProgramBlock(org.apache.sysml.runtime.controlprogram.IfProgramBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) HashMap(java.util.HashMap) ForProgramBlock(org.apache.sysml.runtime.controlprogram.ForProgramBlock) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) WhileProgramBlock(org.apache.sysml.runtime.controlprogram.WhileProgramBlock) HopsException(org.apache.sysml.hops.HopsException) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) ForProgramBlock(org.apache.sysml.runtime.controlprogram.ForProgramBlock) IfProgramBlock(org.apache.sysml.runtime.controlprogram.IfProgramBlock) FunctionProgramBlock(org.apache.sysml.runtime.controlprogram.FunctionProgramBlock) ProgramBlock(org.apache.sysml.runtime.controlprogram.ProgramBlock) WhileProgramBlock(org.apache.sysml.runtime.controlprogram.WhileProgramBlock) DataOp(org.apache.sysml.hops.DataOp) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Example 28 with ForStatementBlock

Use of org.apache.sysml.parser.ForStatementBlock in project incubator-systemml by Apache.

In class InterProceduralAnalysis, method rFlagFunctionForRecompileOnce.

/**
 * Decides whether a statement block (including everything nested in it) requires
 * recompilation when it is executed inside a loop.
 * <p>
 * While and for (incl. parfor) blocks always yield {@code true}, because recompilation
 * information is not yet available when this analysis runs. If blocks and last-level blocks
 * only yield {@code true} when {@code inLoop} is set and the block reports that it needs
 * (predicate) recompilation. Function and if blocks additionally OR in the results of
 * their nested blocks.
 *
 * @param sb statement block to inspect
 * @param inLoop whether {@code sb} is located inside a loop
 * @return true if the block requires recompilation inside a loop statement block
 */
public boolean rFlagFunctionForRecompileOnce(StatementBlock sb, boolean inLoop) {
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatementBlock funcBlock = (FunctionStatementBlock) sb;
        FunctionStatement funcStmt = (FunctionStatement) funcBlock.getStatement(0);
        boolean flagged = false;
        // non-short-circuit OR: every nested block is visited
        for (StatementBlock child : funcStmt.getBody()) {
            flagged |= rFlagFunctionForRecompileOnce(child, inLoop);
        }
        return flagged;
    }

    if (sb instanceof WhileStatementBlock) {
        // Recompilation information is not available yet, so answer conservatively.
        // A finer-grained variant (predicate recompilation if inLoop, plus recursion into
        // the body with inLoop=true) is intentionally disabled.
        return true;
    }

    if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        boolean flagged = inLoop && ifBlock.requiresPredicateRecompilation();
        for (StatementBlock child : ifStmt.getIfBody()) {
            flagged |= rFlagFunctionForRecompileOnce(child, inLoop);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            flagged |= rFlagFunctionForRecompileOnce(child, inLoop);
        }
        return flagged;
    }

    if (sb instanceof ForStatementBlock) {
        // Same conservative answer as for while loops; recursing into the body with
        // inLoop=true is intentionally disabled.
        return true;
    }

    // generic last-level block
    return inLoop && sb.requiresRecompilation();
}
Also used : ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) ExternalFunctionStatement(org.apache.sysml.parser.ExternalFunctionStatement) FunctionStatement(org.apache.sysml.parser.FunctionStatement) IfStatement(org.apache.sysml.parser.IfStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Example 29 with ForStatementBlock

Use of org.apache.sysml.parser.ForStatementBlock in project incubator-systemml by Apache.

In class InterProceduralAnalysis, method moveCheckpointAfterUpdate.

/**
 * Moves checkpoints on top-level variables behind their updates where possible (best effort).
 * <p>
 * Scans the top-level statement blocks of {@code dmlp} in program order and tracks checkpoint
 * candidates by variable name.
 * <ul>
 * <li>A candidate is dropped when a block reads the variable without updating it, unless the
 * block's hop DAGs show no actual read of it.</li>
 * <li>A candidate is also dropped when the variable is updated inside an if, while or for
 * block.</li>
 * <li>When a generic block updates the variable through a root whose input has a simple read
 * chain ({@code HopRewriteUtils.rHasSimpleReadChain}), the checkpoint flag moves from the old
 * candidate hop to that root's first input; otherwise the candidate is dropped.</li>
 * </ul>
 * Checkpoints collected from each block become candidates for the following blocks.
 *
 * @param dmlp DML program whose top-level statement blocks are scanned
 * @throws HopsException if raised by the hop utilities used during the scan
 */
private void moveCheckpointAfterUpdate(DMLProgram dmlp) throws HopsException {
    //approach: scan over top-level program (guaranteed to be unconditional),
    //collect checkpoints; determine if used before update; move first checkpoint
    //after update if not used before update (best effort move which often avoids
    //the second checkpoint on loops even though used in between)
    HashMap<String, Hop> chkpointCand = new HashMap<String, Hop>();
    for (StatementBlock sb : dmlp.getStatementBlocks()) {
        //prune candidates (used before updated)
        Set<String> cands = new HashSet<String>(chkpointCand.keySet());
        for (String cand : cands) {
            if (sb.variablesRead().containsVariable(cand) && !sb.variablesUpdated().containsVariable(cand)) {
                //note: variableRead might include false positives due to meta
                //data operations like nrow(X) or operations removed by rewrites
                //double check hops on basic blocks; otherwise worst-case
                boolean skipRemove = false;
                if (sb.get_hops() != null) {
                    Hop.resetVisitStatus(sb.get_hops());
                    skipRemove = true;
                    for (Hop root : sb.get_hops())
                        skipRemove &= !HopRewriteUtils.rContainsRead(root, cand, false);
                }
                if (!skipRemove)
                    chkpointCand.remove(cand);
            }
        }

        //prune candidates (updated in conditional control flow)
        Set<String> cands2 = new HashSet<String>(chkpointCand.keySet());
        if (sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock) {
            for (String cand : cands2) {
                if (sb.variablesUpdated().containsVariable(cand))
                    chkpointCand.remove(cand);
            }
        } else {
            //move checkpoint after update with simple read chain
            //(note: right now this only applies if the checkpoints comes from a previous
            //statement block, within-dag checkpoints should be handled during injection)
            for (String cand : cands2) {
                if (!sb.variablesUpdated().containsVariable(cand) || sb.get_hops() == null)
                    continue;
                Hop.resetVisitStatus(sb.get_hops());
                for (Hop root : sb.get_hops()) {
                    if (!root.getName().equals(cand))
                        continue;
                    if (HopRewriteUtils.rHasSimpleReadChain(root, cand)) {
                        Hop oldChkpoint = chkpointCand.get(cand);
                        //an earlier root with the same name may already have pruned this
                        //candidate; there is no checkpoint left to move, so skip instead
                        //of dereferencing null
                        if (oldChkpoint == null)
                            continue;
                        Hop newChkpoint = root.getInput().get(0);
                        oldChkpoint.setRequiresCheckpoint(false);
                        newChkpoint.setRequiresCheckpoint(true);
                        chkpointCand.put(cand, newChkpoint);
                    } else {
                        chkpointCand.remove(cand);
                    }
                }
            }
        }

        //collect checkpoints (new candidates for subsequent blocks)
        ArrayList<Hop> tmp = collectCheckpoints(sb.get_hops());
        for (Hop chkpoint : tmp) {
            chkpointCand.put(chkpoint.getName(), chkpoint);
        }
    }
}
Also used : ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) HashMap(java.util.HashMap) Hop(org.apache.sysml.hops.Hop) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) HashSet(java.util.HashSet) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Example 30 with ForStatementBlock

Use of org.apache.sysml.parser.ForStatementBlock in project incubator-systemml by Apache.

In class InterProceduralAnalysis, method propagateStatisticsAcrossBlock.

/////////////////////////////
// INTRA-PROCEDURE ANALYSIS
/////////////////////////////

/**
 * Propagates statistics and constant scalars through one statement block, recursing into
 * nested control flow (intra-procedural part of the analysis).
 * <p>
 * Predicate DAGs of while, if and for (incl. parfor) blocks receive the incoming statistics.
 * Loop bodies are processed once, and a second time if
 * {@code Recompiler.reconcileUpdatedCallVarsLoops} reports a change. If branches are processed
 * on separate copies of {@code callVars} and then reconciled. Last-level blocks get scalar
 * reads replaced by literals, statistics refreshed across their DAGs, and statistics pushed
 * into called functions. Constant scalars updated by a block are removed from
 * {@code callVars} via {@code Recompiler.removeUpdatedScalars}.
 *
 * @param sb  statement block to process
 * @param fcand  function candidates
 * @param callVars  map of variables eligible for propagation; updated in place
 * @param fcandSafeNNZ  function candidate safe non-zeros
 * @param unaryFcands  unary function candidates
 * @param fnStack  function stack to determine the current scope
 * @throws HopsException  if a HopsException occurs
 * @throws ParseException  if a ParseException occurs
 */
private void propagateStatisticsAcrossBlock(StatementBlock sb, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack) throws HopsException, ParseException {
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatementBlock funcBlock = (FunctionStatementBlock) sb;
        FunctionStatement funcStmt = (FunctionStatement) funcBlock.getStatement(0);
        propagateStatisticsAcrossBlockSequence(funcStmt.getBody(), fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
        return;
    }

    if (sb instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) sb;
        WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0);
        // incoming stats into the predicate, then drop scalars the loop updates
        propagateStatisticsAcrossPredicateDAG(whileBlock.getPredicateHops(), callVars);
        Recompiler.removeUpdatedScalars(callVars, whileBlock);
        // first pass over the body, remembering the state before it
        LocalVariableMap beforeBody = (LocalVariableMap) callVars.clone();
        propagateStatisticsAcrossBlockSequence(whileStmt.getBody(), fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
        boolean needsSecondPass = Recompiler.reconcileUpdatedCallVarsLoops(beforeBody, callVars, whileBlock);
        if (needsSecondPass) {
            propagateStatisticsAcrossPredicateDAG(whileBlock.getPredicateHops(), callVars);
            propagateStatisticsAcrossBlockSequence(whileStmt.getBody(), fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
        }
        Recompiler.removeUpdatedScalars(callVars, sb);
        return;
    }

    if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        propagateStatisticsAcrossPredicateDAG(ifBlock.getPredicateHops(), callVars);
        // the if branch works on callVars itself, the else branch on a copy
        LocalVariableMap beforeBranches = (LocalVariableMap) callVars.clone();
        LocalVariableMap elseVars = (LocalVariableMap) callVars.clone();
        propagateStatisticsAcrossBlockSequence(ifStmt.getIfBody(), fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
        propagateStatisticsAcrossBlockSequence(ifStmt.getElseBody(), fcand, elseVars, fcandSafeNNZ, unaryFcands, fnStack);
        // NOTE(review): the reconciled map is only bound locally; the caller observes it only
        // if reconcileUpdatedCallVarsIf updates and returns the if-branch map in place — confirm.
        LocalVariableMap reconciled = Recompiler.reconcileUpdatedCallVarsIf(beforeBranches, callVars, elseVars, ifBlock);
        Recompiler.removeUpdatedScalars(reconciled, sb);
        return;
    }

    if (sb instanceof ForStatementBlock) {
        // covers parfor as well
        ForStatementBlock forBlock = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        // incoming stats into from/to/increment, then drop scalars the loop updates
        propagateStatisticsAcrossPredicateDAG(forBlock.getFromHops(), callVars);
        propagateStatisticsAcrossPredicateDAG(forBlock.getToHops(), callVars);
        propagateStatisticsAcrossPredicateDAG(forBlock.getIncrementHops(), callVars);
        Recompiler.removeUpdatedScalars(callVars, forBlock);
        LocalVariableMap beforeBody = (LocalVariableMap) callVars.clone();
        propagateStatisticsAcrossBlockSequence(forStmt.getBody(), fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
        boolean needsSecondPass = Recompiler.reconcileUpdatedCallVarsLoops(beforeBody, callVars, forBlock);
        if (needsSecondPass) {
            // unlike the while case, the predicates are not revisited here
            propagateStatisticsAcrossBlockSequence(forStmt.getBody(), fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
        }
        Recompiler.removeUpdatedScalars(callVars, sb);
        return;
    }

    // generic last-level block: old stats in, new stats out if updated
    Recompiler.removeUpdatedScalars(callVars, sb);
    ArrayList<Hop> dagRoots = sb.get_hops();
    DMLProgram program = sb.getDMLProg();
    // replace scalar reads with literals
    Hop.resetVisitStatus(dagRoots);
    propagateScalarsAcrossDAG(dagRoots, callVars);
    // refresh statistics across the dag
    Hop.resetVisitStatus(dagRoots);
    propagateStatisticsAcrossDAG(dagRoots, callVars);
    // push statistics into function calls
    Hop.resetVisitStatus(dagRoots);
    propagateStatisticsIntoFunctions(program, dagRoots, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
}

/**
 * Applies {@link #propagateStatisticsAcrossBlock} to each block of a sequence in order,
 * sharing the same {@code callVars} map.
 */
private void propagateStatisticsAcrossBlockSequence(Iterable<StatementBlock> blocks, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack) throws HopsException, ParseException {
    for (StatementBlock block : blocks) {
        propagateStatisticsAcrossBlock(block, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
    }
}
Also used : ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) ExternalFunctionStatement(org.apache.sysml.parser.ExternalFunctionStatement) FunctionStatement(org.apache.sysml.parser.FunctionStatement) IfStatement(org.apache.sysml.parser.IfStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) LocalVariableMap(org.apache.sysml.runtime.controlprogram.LocalVariableMap) ArrayList(java.util.ArrayList) DMLProgram(org.apache.sysml.parser.DMLProgram) WhileStatement(org.apache.sysml.parser.WhileStatement) ForStatement(org.apache.sysml.parser.ForStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Aggregations

ForStatementBlock (org.apache.sysml.parser.ForStatementBlock)31 WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock)27 IfStatementBlock (org.apache.sysml.parser.IfStatementBlock)26 StatementBlock (org.apache.sysml.parser.StatementBlock)26 ForStatement (org.apache.sysml.parser.ForStatement)14 FunctionStatementBlock (org.apache.sysml.parser.FunctionStatementBlock)13 IfStatement (org.apache.sysml.parser.IfStatement)13 WhileStatement (org.apache.sysml.parser.WhileStatement)13 ForProgramBlock (org.apache.sysml.runtime.controlprogram.ForProgramBlock)13 ArrayList (java.util.ArrayList)12 Hop (org.apache.sysml.hops.Hop)12 IfProgramBlock (org.apache.sysml.runtime.controlprogram.IfProgramBlock)12 WhileProgramBlock (org.apache.sysml.runtime.controlprogram.WhileProgramBlock)12 FunctionStatement (org.apache.sysml.parser.FunctionStatement)8 FunctionProgramBlock (org.apache.sysml.runtime.controlprogram.FunctionProgramBlock)8 ParForStatementBlock (org.apache.sysml.parser.ParForStatementBlock)7 ProgramBlock (org.apache.sysml.runtime.controlprogram.ProgramBlock)7 ParForProgramBlock (org.apache.sysml.runtime.controlprogram.ParForProgramBlock)6 LocalVariableMap (org.apache.sysml.runtime.controlprogram.LocalVariableMap)5 ExternalFunctionStatement (org.apache.sysml.parser.ExternalFunctionStatement)4