Search in sources :

Example 31 with AggUnaryOp

use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifyNnzComputation.

/**
 * Rewrites {@code sum(X != 0)} (or {@code sum(0 != X)}) into a literal that
 * holds the non-zero count reported by {@code X}.
 *
 * The rewrite is applied only if {@code hi} is a full (row-and-column) sum
 * aggregate over a NOTEQUAL binary operation with a literal zero on one side,
 * and only if {@code X.getNnz()} is strictly positive.
 *
 * @param parent parent hop whose child reference is replaced on success
 * @param hi     current hop, candidate root of the pattern
 * @param pos    position of {@code hi} in the inputs of {@code parent}
 * @return the new literal hop if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyNnzComputation(Hop parent, Hop hi, int pos) {
    // pattern root must be an aggregate ...
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    // ... more precisely a full sum, i.e., sum(...) over rows and columns
    AggUnaryOp sumAgg = (AggUnaryOp) hi;
    if (sumAgg.getOp() != AggOp.SUM || sumAgg.getDirection() != Direction.RowCol) {
        return hi;
    }
    // ... over a "!=" comparison
    if (!HopRewriteUtils.isBinary(hi.getInput().get(0), OpOp2.NOTEQUAL)) {
        return hi;
    }
    Hop notEqual = hi.getInput().get(0);

    // find the literal zero (left side is checked first); X is the other operand
    Hop X = null;
    for (int zeroPos = 0; zeroPos < 2; zeroPos++) {
        Hop candidate = notEqual.getInput().get(zeroPos);
        if (candidate instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) candidate) == 0) {
            X = notEqual.getInput().get(1 - zeroPos);
            break;
        }
    }

    // no zero literal found, or no positive nnz available: leave the dag untouched
    if (X == null || X.getNnz() <= 0) {
        return hi;
    }

    // replace the entire aggregate by the literal nnz
    Hop nnzLiteral = new LiteralOp(X.getNnz());
    HopRewriteUtils.replaceChildReference(parent, hi, nnzLiteral, pos);
    HopRewriteUtils.cleanupUnreferenced(hi);
    LOG.debug("Applied simplifyNnzComputation.");
    return nnzLiteral;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 32 with AggUnaryOp

use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationDynamic method simplifyRowSumsMVMult.

/**
 * Rewrites {@code rowSums(X * y)} into {@code X %*% t(y)} when {@code X} has
 * more than one row and column and {@code y} is a row vector (one row, more
 * than one column).
 *
 * Per the original note on this rewrite, the introduced transpose is removed
 * by another rewrite if unnecessary, i.e., if Y==t(Z).
 *
 * @param parent parent hop whose child reference is replaced on success
 * @param hi     current hop, candidate root of the pattern
 * @param pos    position of {@code hi} in the inputs of {@code parent}
 * @return the new matrix multiply hop if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyRowSumsMVMult(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    AggUnaryOp rowAgg = (AggUnaryOp) hi;
    Hop product = rowAgg.getInput().get(0);

    // rowSums over an element-wise multiply, b(*)
    if (rowAgg.getOp() != AggOp.SUM || rowAgg.getDirection() != Direction.Row) {
        return hi;
    }
    if (!HopRewriteUtils.isBinary(product, OpOp2.MULT)) {
        return hi;
    }

    Hop matrix = product.getInput().get(0);
    Hop rowVector = product.getInput().get(1);

    // MV (row vector): left must be a true matrix, right a 1 x n vector with n > 1
    if (matrix.getDim1() <= 1 || matrix.getDim2() <= 1
        || rowVector.getDim1() != 1 || rowVector.getDim2() <= 1) {
        return hi;
    }

    // create new operators
    ReorgOp transposed = HopRewriteUtils.createTranspose(rowVector);
    AggBinaryOp matMult = HopRewriteUtils.createMatrixMultiply(matrix, transposed);

    // relink new child
    HopRewriteUtils.replaceChildReference(parent, hi, matMult, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, product);
    LOG.debug("Applied simplifyRowSumsMVMult");
    return matMult;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp)

Example 33 with AggUnaryOp

use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationStatic method simplifyTraceMatrixMult.

/**
 * Rewrites {@code trace(X %*% Y)} into {@code sum(X * t(Y))}.
 *
 * @param parent parent hop whose child reference is replaced on success
 * @param hi     current hop, candidate root of the pattern
 * @param pos    position of {@code hi} in the inputs of {@code parent}
 * @return the new sum aggregate if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyTraceMatrixMult(Hop parent, Hop hi, int pos) {
    // root must be trace()
    boolean isTrace = hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() == AggOp.TRACE;
    if (!isTrace) {
        return hi;
    }

    // input must be X%*%Y
    Hop matMult = hi.getInput().get(0);
    if (!HopRewriteUtils.isMatrixMultiply(matMult)) {
        return hi;
    }
    Hop lhs = matMult.getInput().get(0);
    Hop rhs = matMult.getInput().get(1);

    // create new operators (incl refresh size inside for transpose)
    ReorgOp rhsTransposed = HopRewriteUtils.createTranspose(rhs);
    BinaryOp elementwise = HopRewriteUtils.createBinary(lhs, rhsTransposed, OpOp2.MULT);
    AggUnaryOp total = HopRewriteUtils.createSum(elementwise);

    // rehang new subdag under parent node
    HopRewriteUtils.replaceChildReference(parent, hi, total, pos);
    HopRewriteUtils.cleanupUnreferenced(hi, matMult);
    LOG.debug("Applied simplifyTraceMatrixMult");
    return total;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 34 with AggUnaryOp

use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.

the class RewriteAlgebraicSimplificationStatic method removeUnnecessaryAggregates.

/**
 * Collapses a full aggregate over another aggregate into a single aggregate
 * over the inner aggregate's input.
 *
 * Handled combinations (outer over inner): SUM over SUM, SUM over SUM_SQ
 * (the outer operation then becomes SUM_SQ), MIN over MIN, and MAX over MAX.
 * The rewrite applies only if the outer aggregate has direction RowCol and
 * the inner aggregate has exactly one parent.
 *
 * @param hi current hop, candidate outer aggregate
 * @return {@code hi}, modified in place if the rewrite applied
 */
private static Hop removeUnnecessaryAggregates(Hop hi) {
    // both the hop and its first input must be unary aggregates
    if (!(hi instanceof AggUnaryOp) || !(hi.getInput().get(0) instanceof AggUnaryOp)) {
        return hi;
    }
    // outer must be a full aggregate and the sole consumer of the inner one
    if (((AggUnaryOp) hi).getDirection() != Direction.RowCol
        || hi.getInput().get(0).getParent().size() != 1) {
        return hi;
    }

    AggUnaryOp outer = (AggUnaryOp) hi;
    AggUnaryOp inner = (AggUnaryOp) hi.getInput().get(0);

    boolean sumOverSum = outer.getOp() == AggOp.SUM
        && (inner.getOp() == AggOp.SUM || inner.getOp() == AggOp.SUM_SQ);
    boolean minOverMin = outer.getOp() == AggOp.MIN && inner.getOp() == AggOp.MIN;
    boolean maxOverMax = outer.getOp() == AggOp.MAX && inner.getOp() == AggOp.MAX;
    if (!(sumOverSum || minOverMin || maxOverMax)) {
        return hi;
    }

    // bypass the inner aggregate: outer now reads the inner's input directly
    Hop innerInput = inner.getInput().get(0);
    HopRewriteUtils.removeAllChildReferences(inner);
    HopRewriteUtils.replaceChildReference(outer, inner, innerInput);

    // the squaring of the inner aggregate must be carried over to the outer one
    if (inner.getOp() == AggOp.SUM_SQ) {
        outer.setOp(AggOp.SUM_SQ);
    }
    LOG.debug("Applied removeUnnecessaryAggregates (line " + hi.getBeginLine() + ").");
    return hi;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) Hop(org.apache.sysml.hops.Hop)

Example 35 with AggUnaryOp

use of org.apache.sysml.hops.AggUnaryOp in project incubator-systemml by apache.

the class PlanSelectionFuseCostBased method rGetComputeCosts.

/**
 * Recursively assigns a compute-cost weight to the given hop and all hops
 * reachable through its inputs, storing each weight in {@code computeCosts}
 * keyed by hop ID.
 *
 * Hops whose ID is already present in the map are skipped. Inputs are
 * processed before the hop itself. The default weight is 1; operations
 * without an explicit weight keep that default, and for unary, binary,
 * ternary, and unary-aggregate hops a warning is logged in that case.
 *
 * @param current      hop to cost
 * @param partition    passed through unchanged to the recursive calls; not
 *                     otherwise read by this method
 * @param computeCosts output map from hop ID to cost weight, also used as
 *                     the visited set
 */
private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts) {
    // already costed
    if (computeCosts.containsKey(current.getHopID())) {
        return;
    }

    // recursively process children
    for (Hop child : current.getInput()) {
        rGetComputeCosts(child, partition, computeCosts);
    }

    // get costs for given hop (default weight, kept for any unlisted operation)
    double costs = 1;
    if (current instanceof UnaryOp) {
        UnaryOp unary = (UnaryOp) current;
        switch (unary.getOp()) {
            case ABS:
            case ROUND:
            case CEIL:
            case FLOOR:
            case SIGN:
            case SELP:
            case NCOL:
            case NROW:
            case PRINT:
            case CAST_AS_BOOLEAN:
            case CAST_AS_DOUBLE:
            case CAST_AS_INT:
            case CAST_AS_MATRIX:
            case CAST_AS_SCALAR:
            case CUMSUM:
            case CUMMIN:
            case CUMMAX:
            case CUMPROD:
                costs = 1;
                break;
            case SPROP:
            case SQRT:
                costs = 2;
                break;
            case EXP:
            case SIN:
                costs = 18;
                break;
            case SIGMOID:
                costs = 21;
                break;
            case COS:
                costs = 22;
                break;
            case LOG:
            case LOG_NZ:
                costs = 32;
                break;
            case ATAN:
                costs = 40;
                break;
            case TAN:
                costs = 42;
                break;
            case ASIN:
                costs = 93;
                break;
            case ACOS:
                costs = 103;
                break;
            default:
                LOG.warn("Cost model not implemented yet for: " + unary.getOp());
        }
    } else if (current instanceof BinaryOp) {
        BinaryOp binary = (BinaryOp) current;
        switch (binary.getOp()) {
            case MULT:
            case PLUS:
            case MINUS:
            case MIN:
            case MAX:
            case AND:
            case OR:
            case EQUAL:
            case NOTEQUAL:
            case LESS:
            case LESSEQUAL:
            case GREATER:
            case GREATEREQUAL:
            case CBIND:
            case RBIND:
                costs = 1;
                break;
            case MINUS_NZ:
            case MINUS1_MULT:
                costs = 2;
                break;
            case INTDIV:
                costs = 6;
                break;
            case MODULUS:
                costs = 8;
                break;
            case DIV:
                costs = 22;
                break;
            case COVARIANCE:
                costs = 23;
                break;
            case LOG:
            case LOG_NZ:
                costs = 32;
                break;
            case POW:
                // a literal exponent of 2 is weighted like a plain multiply
                if (HopRewriteUtils.isLiteralOfValue(current.getInput().get(1), 2)) {
                    costs = 1;
                } else {
                    costs = 16;
                }
                break;
            case CENTRALMOMENT: {
                // moment type from the second input if literal, otherwise 2
                Hop typeInput = current.getInput().get(1);
                int momentType = (int) (typeInput instanceof LiteralOp ? HopRewriteUtils.getIntValueSafe((LiteralOp) typeInput) : 2);
                switch (momentType) {
                    case 0: // count
                        costs = 1;
                        break;
                    case 1: // mean
                        costs = 8;
                        break;
                    case 2: // cm2
                    case 5: // variance
                        costs = 16;
                        break;
                    case 3: // cm3
                        costs = 31;
                        break;
                    case 4: // cm4
                        costs = 51;
                        break;
                }
                break;
            }
            default:
                LOG.warn("Cost model not implemented yet for: " + binary.getOp());
        }
    } else if (current instanceof TernaryOp) {
        TernaryOp ternary = (TernaryOp) current;
        switch (ternary.getOp()) {
            case PLUS_MULT:
            case MINUS_MULT:
                costs = 2;
                break;
            case CTABLE:
                costs = 3;
                break;
            case COVARIANCE:
                costs = 23;
                break;
            case CENTRALMOMENT: {
                // NOTE(review): reads the moment type from input 1, same as the
                // binary case - confirm this is the intended input for ternary hops
                Hop typeInput = current.getInput().get(1);
                int momentType = (int) (typeInput instanceof LiteralOp ? HopRewriteUtils.getIntValueSafe((LiteralOp) typeInput) : 2);
                switch (momentType) {
                    case 0: // count
                        costs = 2;
                        break;
                    case 1: // mean
                        costs = 9;
                        break;
                    case 2: // cm2
                    case 5: // variance
                        costs = 17;
                        break;
                    case 3: // cm3
                        costs = 32;
                        break;
                    case 4: // cm4
                        costs = 52;
                        break;
                }
                break;
            }
            default:
                LOG.warn("Cost model not implemented yet for: " + ternary.getOp());
        }
    } else if (current instanceof ParameterizedBuiltinOp) {
        costs = 1;
    } else if (current instanceof IndexingOp) {
        costs = 1;
    } else if (current instanceof ReorgOp) {
        costs = 1;
    } else if (current instanceof AggBinaryOp) {
        // matrix vector
        costs = 2;
    } else if (current instanceof AggUnaryOp) {
        AggUnaryOp aggregate = (AggUnaryOp) current;
        switch (aggregate.getOp()) {
            case MIN:
            case MAX:
                costs = 1;
                break;
            case SUM:
                costs = 4;
                break;
            case SUM_SQ:
                costs = 5;
                break;
            default:
                LOG.warn("Cost model not implemented yet for: " + aggregate.getOp());
        }
    }

    computeCosts.put(current.getHopID(), costs);
}
Also used : ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) IndexingOp(org.apache.sysml.hops.IndexingOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp) TernaryOp(org.apache.sysml.hops.TernaryOp)

Aggregations

AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 36, Hop (org.apache.sysml.hops.Hop): 33, AggBinaryOp (org.apache.sysml.hops.AggBinaryOp): 19, LiteralOp (org.apache.sysml.hops.LiteralOp): 14, BinaryOp (org.apache.sysml.hops.BinaryOp): 12, ReorgOp (org.apache.sysml.hops.ReorgOp): 11, UnaryOp (org.apache.sysml.hops.UnaryOp): 11, IndexingOp (org.apache.sysml.hops.IndexingOp): 7, MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry): 6, ArrayList (java.util.ArrayList): 5, TernaryOp (org.apache.sysml.hops.TernaryOp): 5, DataOp (org.apache.sysml.hops.DataOp): 4, ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp): 4, CNode (org.apache.sysml.hops.codegen.cplan.CNode): 3, CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary): 3, CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData): 3, CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary): 3, OpOp2 (org.apache.sysml.hops.Hop.OpOp2): 2, QuaternaryOp (org.apache.sysml.hops.QuaternaryOp): 2, CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary): 2