Search in sources :

Example 11 with FunctionStatement

use of org.apache.sysml.parser.FunctionStatement in project incubator-systemml by apache.

The method generateCodeFromStatementBlock of the class SpoofCompiler.

/**
 * Recursively walks a statement block and applies code generation to it.
 *
 * <p>Function, while, if and for blocks are descended into: their predicate
 * (or from/to/increment) hops, where present, are replaced by the result of
 * {@code optimize(hops, false)} before the child blocks are visited. Any other
 * block has its hop list replaced by the result of
 * {@code generateCodeFromHopDAGs} and its recompilation flag updated.
 *
 * @param current statement block to process
 */
public static void generateCodeFromStatementBlock(StatementBlock current) {
    // function: only the body is visited
    if (current instanceof FunctionStatementBlock) {
        FunctionStatementBlock funBlock = (FunctionStatementBlock) current;
        FunctionStatement funStmt = (FunctionStatement) funBlock.getStatement(0);
        for (StatementBlock child : funStmt.getBody()) {
            generateCodeFromStatementBlock(child);
        }
        return;
    }

    // while: predicate first, then body
    if (current instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) current;
        WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0);
        whileBlock.setPredicateHops(optimize(whileBlock.getPredicateHops(), false));
        for (StatementBlock child : whileStmt.getBody()) {
            generateCodeFromStatementBlock(child);
        }
        return;
    }

    // if: predicate first, then if-branch followed by else-branch
    if (current instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) current;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        ifBlock.setPredicateHops(optimize(ifBlock.getPredicateHops(), false));
        for (StatementBlock child : ifStmt.getIfBody()) {
            generateCodeFromStatementBlock(child);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            generateCodeFromStatementBlock(child);
        }
        return;
    }

    // for (incl parfor): from/to/increment first, then body
    if (current instanceof ForStatementBlock) {
        ForStatementBlock forBlock = (ForStatementBlock) current;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        forBlock.setFromHops(optimize(forBlock.getFromHops(), false));
        forBlock.setToHops(optimize(forBlock.getToHops(), false));
        forBlock.setIncrementHops(optimize(forBlock.getIncrementHops(), false));
        for (StatementBlock child : forStmt.getBody()) {
            generateCodeFromStatementBlock(child);
        }
        return;
    }

    // generic (last-level)
    current.setHops(generateCodeFromHopDAGs(current.getHops()));
    current.updateRecompilationFlag();
}
Also used : ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) FunctionStatement(org.apache.sysml.parser.FunctionStatement) IfStatement(org.apache.sysml.parser.IfStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) WhileStatement(org.apache.sysml.parser.WhileStatement) ForStatement(org.apache.sysml.parser.ForStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Example 12 with FunctionStatement

use of org.apache.sysml.parser.FunctionStatement in project incubator-systemml by apache.

The method rAnalyzeProgram of the class RewriteCompressedReblock.

/**
 * Recursively analyzes a statement block and records the findings in
 * {@code status}.
 *
 * <p>Control-flow blocks are descended into first; afterwards loops set
 * {@code status.usedInLoop} when they read any name in {@code status.compMtx},
 * and if-blocks set {@code status.condUpdate} when they update any such name.
 * A last-level block with non-null hops has each of its DAG roots analyzed via
 * {@code rAnalyzeHopDag}, after which names starting with {@code TMP_PREFIX}
 * are removed from {@code status.compMtx}.
 *
 * @param sb     statement block to analyze
 * @param status analysis state that is read and updated in place
 */
private static void rAnalyzeProgram(StatementBlock sb, ProbeStatus status) {
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatementBlock funBlock = (FunctionStatementBlock) sb;
        FunctionStatement funStmt = (FunctionStatement) funBlock.getStatement(0);
        for (StatementBlock child : funStmt.getBody()) {
            rAnalyzeProgram(child, status);
        }
        return;
    }

    if (sb instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) sb;
        WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0);
        for (StatementBlock child : whileStmt.getBody()) {
            rAnalyzeProgram(child, status);
        }
        // checked after the body has been analyzed
        if (whileBlock.variablesRead().containsAnyName(status.compMtx)) {
            status.usedInLoop = true;
        }
        return;
    }

    if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        for (StatementBlock child : ifStmt.getIfBody()) {
            rAnalyzeProgram(child, status);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            rAnalyzeProgram(child, status);
        }
        if (ifBlock.variablesUpdated().containsAnyName(status.compMtx)) {
            status.condUpdate = true;
        }
        return;
    }

    // incl parfor
    if (sb instanceof ForStatementBlock) {
        ForStatementBlock forBlock = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        for (StatementBlock child : forStmt.getBody()) {
            rAnalyzeProgram(child, status);
        }
        if (forBlock.variablesRead().containsAnyName(status.compMtx)) {
            status.usedInLoop = true;
        }
        return;
    }

    // generic (last-level); nothing to do without hops
    if (sb.getHops() == null) {
        return;
    }
    ArrayList<Hop> dagRoots = sb.getHops();
    Hop.resetVisitStatus(dagRoots);
    // process entire HOP DAG starting from the roots
    for (Hop dagRoot : dagRoots) {
        rAnalyzeHopDag(dagRoot, status);
    }
    // remove temporary variables
    status.compMtx.removeIf(name -> name.startsWith(TMP_PREFIX));
    Hop.resetVisitStatus(dagRoots);
}
Also used : ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) FunctionStatement(org.apache.sysml.parser.FunctionStatement) IfStatement(org.apache.sysml.parser.IfStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) Hop(org.apache.sysml.hops.Hop) WhileStatement(org.apache.sysml.parser.WhileStatement) ForStatement(org.apache.sysml.parser.ForStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Example 13 with FunctionStatement

use of org.apache.sysml.parser.FunctionStatement in project incubator-systemml by apache.

The method rAnalyzeHopDag of the class RewriteCompressedReblock.

/**
 * Recursively analyzes a HOP DAG bottom-up and records the findings in
 * {@code status}.
 *
 * <p>Already visited hops are skipped. Inputs are processed before the hop
 * itself. The hop whose ID equals {@code status.startHopID} seeds
 * {@code status.compMtx} and sets {@code status.foundStart}. After that,
 * exactly one of the following cases is applied:
 * <ul>
 *   <li>a) a {@code FunctionOp} for which {@code hasCompressedInput} holds:
 *       the called function is analyzed once per function key with a separate
 *       {@code ProbeStatus}, and its flags and outputs are merged back;</li>
 *   <li>b) transient writes/reads: names are mapped between variable names and
 *       the names produced by {@code getTmpName};</li>
 *   <li>c) any other hop for which {@code hasCompressedInput} holds: the hop
 *       is classified, and {@code status.nonApplicable} is set if it falls
 *       into none of the accepted categories.</li>
 * </ul>
 * Finally the hop is marked as visited.
 *
 * @param current hop to analyze
 * @param status  analysis state that is read and updated in place
 */
private static void rAnalyzeHopDag(Hop current, ProbeStatus status) {
    if (current.isVisited())
        return;
    // process children recursively
    for (Hop input : current.getInput()) rAnalyzeHopDag(input, status);
    // handle source persistent read
    if (current.getHopID() == status.startHopID) {
        status.compMtx.add(getTmpName(current));
        status.foundStart = true;
    }
    // a) handle function calls
    if (current instanceof FunctionOp && hasCompressedInput(current, status)) {
        // TODO handle of functions in a more fine-grained manner
        // to cover special cases multiple calls where compressed
        // inputs might occur for different input parameters
        FunctionOp fop = (FunctionOp) current;
        String fkey = fop.getFunctionKey();
        if (!status.procFn.contains(fkey)) {
            // memoization to avoid redundant analysis and recursive calls
            status.procFn.add(fkey);
            // map inputs to function inputs
            // NOTE(review): fsb is dereferenced without a null check; presumably the
            // program always holds a block for a called function key - confirm
            FunctionStatementBlock fsb = status.prog.getFunctionStatementBlock(fkey);
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
            // separate status for the callee, constructed from the caller's status
            ProbeStatus status2 = new ProbeStatus(status);
            // NOTE(review): assumes the i-th call input corresponds to the i-th declared
            // input parameter and that there are at least as many parameters - confirm
            for (int i = 0; i < fop.getInput().size(); i++) if (status.compMtx.contains(getTmpName(fop.getInput().get(i))))
                status2.compMtx.add(fstmt.getInputParams().get(i).getName());
            // analyze function and merge meta info
            rAnalyzeProgram(fsb, status2);
            // flags are OR-merged, so a flag set in the callee is also set for the caller
            status.foundStart |= status2.foundStart;
            status.usedInLoop |= status2.usedInLoop;
            status.condUpdate |= status2.condUpdate;
            status.nonApplicable |= status2.nonApplicable;
            // map function outputs to outputs
            // NOTE(review): same positional assumption as for the inputs above - confirm
            String[] outputs = fop.getOutputVariableNames();
            for (int i = 0; i < outputs.length; i++) if (status2.compMtx.contains(fstmt.getOutputParams().get(i).getName()))
                status.compMtx.add(outputs[i]);
        }
    } else // b) handle transient reads and writes (name mapping)
    // transient write: the written variable name is tracked if its input's tmp name is tracked
    if (HopRewriteUtils.isData(current, DataOpTypes.TRANSIENTWRITE) && status.compMtx.contains(getTmpName(current.getInput().get(0))))
        status.compMtx.add(current.getName());
    // transient read: the hop's tmp name is tracked if the read variable name is tracked
    else if (HopRewriteUtils.isData(current, DataOpTypes.TRANSIENTREAD) && status.compMtx.contains(current.getName()))
        status.compMtx.add(getTmpName(current));
    else // c) handle applicable operations
    if (hasCompressedInput(current, status)) {
        // valid with uncompressed outputs
        // accepted: tsmm (left), matrix-vector/vector-matrix multiply, a transpose whose
        // single parent is such a multiply, and sum/sum_sq/min/max unary aggregates
        boolean compUCOut = (// tsmm
        current instanceof AggBinaryOp && current.getDim2() <= current.getColsInBlock() && ((AggBinaryOp) current).checkTransposeSelf() == MMTSJType.LEFT) || // mvmm
        (current instanceof AggBinaryOp && (current.getDim1() == 1 || current.getDim2() == 1)) || (HopRewriteUtils.isTransposeOperation(current) && current.getParent().size() == 1 && current.getParent().get(0) instanceof AggBinaryOp && (current.getParent().get(0).getDim1() == 1 || current.getParent().get(0).getDim2() == 1)) || HopRewriteUtils.isAggUnaryOp(current, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX);
        // valid with compressed outputs
        boolean compCOut = HopRewriteUtils.isBinaryMatrixScalarOperation(current) || HopRewriteUtils.isBinary(current, OpOp2.CBIND);
        // nrow/ncol queries
        boolean metaOp = HopRewriteUtils.isUnary(current, OpOp1.NROW, OpOp1.NCOL);
        // any other operation makes the analysis result non-applicable
        status.nonApplicable |= !(compUCOut || compCOut || metaOp);
        // only operations in the compressed-output category propagate tracking to their result
        if (compCOut)
            status.compMtx.add(getTmpName(current));
    }
    current.setVisited();
}
Also used : FunctionStatement(org.apache.sysml.parser.FunctionStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) FunctionOp(org.apache.sysml.hops.FunctionOp)

Example 14 with FunctionStatement

use of org.apache.sysml.parser.FunctionStatement in project incubator-systemml by apache.

The method propagateStatisticsAcrossBlock of the class InterProceduralAnalysis.

// ///////////////////////////
// INTRA-PROCEDURE ANALYSIS
// ////
/**
 * Propagates statistics held in {@code callVars} through a statement block,
 * recursing into the child blocks of control-flow constructs.
 *
 * <p>Loops propagate into their predicate (or from/to/increment) DAGs, then
 * into the body, and run a second pass over the body when
 * {@code Recompiler.reconcileUpdatedCallVarsLoops} returns true. If-blocks
 * propagate into both branches using separate variable maps, which are then
 * reconciled. Last-level blocks propagate scalars and statistics across their
 * hop DAGs and into function calls.
 *
 * @param sb         statement block to process
 * @param callVars   variable map carrying the statistics; updated in place
 * @param fcallSizes function call size information passed through to function propagation
 * @param fnStack    function name set passed through to function propagation
 */
private void propagateStatisticsAcrossBlock(StatementBlock sb, LocalVariableMap callVars, FunctionCallSizeInfo fcallSizes, Set<String> fnStack) {
    if (sb instanceof FunctionStatementBlock) {
        FunctionStatementBlock funBlock = (FunctionStatementBlock) sb;
        FunctionStatement funStmt = (FunctionStatement) funBlock.getStatement(0);
        for (StatementBlock child : funStmt.getBody()) {
            propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
        }
        return;
    }

    if (sb instanceof WhileStatementBlock) {
        WhileStatementBlock whileBlock = (WhileStatementBlock) sb;
        WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0);
        // old stats into predicate
        propagateStatisticsAcrossPredicateDAG(whileBlock.getPredicateHops(), callVars);
        // remove updated constant scalars
        Recompiler.removeUpdatedScalars(callVars, whileBlock);
        // check and propagate stats into body
        LocalVariableMap varsBeforeBody = (LocalVariableMap) callVars.clone();
        for (StatementBlock child : whileStmt.getBody()) {
            propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
        }
        if (Recompiler.reconcileUpdatedCallVarsLoops(varsBeforeBody, callVars, whileBlock)) {
            // second pass if required
            propagateStatisticsAcrossPredicateDAG(whileBlock.getPredicateHops(), callVars);
            for (StatementBlock child : whileStmt.getBody()) {
                propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
            }
        }
        // remove updated constant scalars
        Recompiler.removeUpdatedScalars(callVars, sb);
        return;
    }

    if (sb instanceof IfStatementBlock) {
        IfStatementBlock ifBlock = (IfStatementBlock) sb;
        IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
        // old stats into predicate
        propagateStatisticsAcrossPredicateDAG(ifBlock.getPredicateHops(), callVars);
        // check and propagate stats into body; the else-branch gets its own copy
        LocalVariableMap varsBeforeBranches = (LocalVariableMap) callVars.clone();
        LocalVariableMap elseBranchVars = (LocalVariableMap) callVars.clone();
        for (StatementBlock child : ifStmt.getIfBody()) {
            propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
        }
        for (StatementBlock child : ifStmt.getElseBody()) {
            propagateStatisticsAcrossBlock(child, elseBranchVars, fcallSizes, fnStack);
        }
        LocalVariableMap reconciledVars = Recompiler.reconcileUpdatedCallVarsIf(varsBeforeBranches, callVars, elseBranchVars, ifBlock);
        // remove updated constant scalars
        Recompiler.removeUpdatedScalars(reconciledVars, sb);
        return;
    }

    // incl parfor
    if (sb instanceof ForStatementBlock) {
        ForStatementBlock forBlock = (ForStatementBlock) sb;
        ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
        // old stats into predicate
        propagateStatisticsAcrossPredicateDAG(forBlock.getFromHops(), callVars);
        propagateStatisticsAcrossPredicateDAG(forBlock.getToHops(), callVars);
        propagateStatisticsAcrossPredicateDAG(forBlock.getIncrementHops(), callVars);
        // remove updated constant scalars
        Recompiler.removeUpdatedScalars(callVars, forBlock);
        // check and propagate stats into body
        LocalVariableMap varsBeforeBody = (LocalVariableMap) callVars.clone();
        for (StatementBlock child : forStmt.getBody()) {
            propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
        }
        // second pass over the body if required
        if (Recompiler.reconcileUpdatedCallVarsLoops(varsBeforeBody, callVars, forBlock)) {
            for (StatementBlock child : forStmt.getBody()) {
                propagateStatisticsAcrossBlock(child, callVars, fcallSizes, fnStack);
            }
        }
        // remove updated constant scalars
        Recompiler.removeUpdatedScalars(callVars, sb);
        return;
    }

    // generic (last-level)
    // remove updated constant scalars
    Recompiler.removeUpdatedScalars(callVars, sb);
    // old stats in, new stats out if updated
    ArrayList<Hop> dagRoots = sb.getHops();
    DMLProgram program = sb.getDMLProg();
    // replace scalar reads with literals
    Hop.resetVisitStatus(dagRoots);
    propagateScalarsAcrossDAG(dagRoots, callVars);
    // refresh stats across dag
    Hop.resetVisitStatus(dagRoots);
    propagateStatisticsAcrossDAG(dagRoots, callVars);
    // propagate stats into function calls
    Hop.resetVisitStatus(dagRoots);
    propagateStatisticsIntoFunctions(program, dagRoots, callVars, fcallSizes, fnStack);
}
Also used : ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) ExternalFunctionStatement(org.apache.sysml.parser.ExternalFunctionStatement) FunctionStatement(org.apache.sysml.parser.FunctionStatement) IfStatement(org.apache.sysml.parser.IfStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) LocalVariableMap(org.apache.sysml.runtime.controlprogram.LocalVariableMap) ArrayList(java.util.ArrayList) DMLProgram(org.apache.sysml.parser.DMLProgram) WhileStatement(org.apache.sysml.parser.WhileStatement) ForStatement(org.apache.sysml.parser.ForStatement) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock)

Example 15 with FunctionStatement

use of org.apache.sysml.parser.FunctionStatement in project incubator-systemml by apache.

The method isUnarySizePreservingFunction of the class InterProceduralAnalysis.

/**
 * Checks whether a function takes exactly one matrix input, returns exactly
 * one matrix output, and yields an output with the same number of rows and
 * columns as its input when statistics are propagated through its body.
 *
 * <p>The check binds a probe matrix created with dimensions 7777 x 3333 to the
 * input parameter and propagates statistics through the body. Afterwards the
 * probe's dimensions are set to -1 and the body is propagated once more to
 * reset the function.
 *
 * @param fsb function statement block to test
 * @return true if the function is unary over matrices and size-preserving
 */
private boolean isUnarySizePreservingFunction(FunctionStatementBlock fsb) {
    FunctionStatement funStmt = (FunctionStatement) fsb.getStatement(0);

    // check unary functions over matrices
    if (funStmt.getInputParams().size() != 1
            || funStmt.getInputParams().get(0).getDataType() != DataType.MATRIX
            || funStmt.getOutputParams().size() != 1
            || funStmt.getOutputParams().get(0).getDataType() != DataType.MATRIX) {
        return false;
    }

    // check size-preserving characteristic
    FunctionCallSizeInfo callSizes = new FunctionCallSizeInfo(_fgraph, false);
    Set<String> visitedFns = new HashSet<>();
    LocalVariableMap probeVars = new LocalVariableMap();

    // populate input
    MatrixObject probeIn = createOutputMatrix(7777, 3333, -1);
    probeVars.put(funStmt.getInputParams().get(0).getName(), probeIn);

    // propagate statistics
    for (StatementBlock child : funStmt.getBody()) {
        propagateStatisticsAcrossBlock(child, probeVars, callSizes, visitedFns);
    }

    // compare output
    MatrixObject probeOut = (MatrixObject) probeVars.get(funStmt.getOutputParams().get(0).getName());
    boolean sizePreserving = probeIn.getNumRows() == probeOut.getNumRows()
            && probeIn.getNumColumns() == probeOut.getNumColumns();

    // reset function
    probeIn.getMatrixCharacteristics().setDimension(-1, -1);
    for (StatementBlock child : funStmt.getBody()) {
        propagateStatisticsAcrossBlock(child, probeVars, callSizes, visitedFns);
    }

    return sizePreserving;
}
Also used : ExternalFunctionStatement(org.apache.sysml.parser.ExternalFunctionStatement) FunctionStatement(org.apache.sysml.parser.FunctionStatement) MatrixObject(org.apache.sysml.runtime.controlprogram.caching.MatrixObject) LocalVariableMap(org.apache.sysml.runtime.controlprogram.LocalVariableMap) FunctionStatementBlock(org.apache.sysml.parser.FunctionStatementBlock) IfStatementBlock(org.apache.sysml.parser.IfStatementBlock) WhileStatementBlock(org.apache.sysml.parser.WhileStatementBlock) ForStatementBlock(org.apache.sysml.parser.ForStatementBlock) StatementBlock(org.apache.sysml.parser.StatementBlock) HashSet(java.util.HashSet)

Aggregations

FunctionStatement (org.apache.sysml.parser.FunctionStatement): 26; FunctionStatementBlock (org.apache.sysml.parser.FunctionStatementBlock): 24; StatementBlock (org.apache.sysml.parser.StatementBlock): 22; ForStatementBlock (org.apache.sysml.parser.ForStatementBlock): 19; IfStatementBlock (org.apache.sysml.parser.IfStatementBlock): 19; WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock): 19; IfStatement (org.apache.sysml.parser.IfStatement): 14; ExternalFunctionStatement (org.apache.sysml.parser.ExternalFunctionStatement): 13; ForStatement (org.apache.sysml.parser.ForStatement): 12; WhileStatement (org.apache.sysml.parser.WhileStatement): 12; ArrayList (java.util.ArrayList): 11; Hop (org.apache.sysml.hops.Hop): 9; ParForStatementBlock (org.apache.sysml.parser.ParForStatementBlock): 9; FunctionOp (org.apache.sysml.hops.FunctionOp): 6; LocalVariableMap (org.apache.sysml.runtime.controlprogram.LocalVariableMap): 6; DMLProgram (org.apache.sysml.parser.DMLProgram): 4; DataIdentifier (org.apache.sysml.parser.DataIdentifier): 3; HashSet (java.util.HashSet): 2; LiteralOp (org.apache.sysml.hops.LiteralOp): 2; ForProgramBlock (org.apache.sysml.runtime.controlprogram.ForProgramBlock): 2