Search in sources:

Example 21 with IndexingOp

Use of org.apache.sysml.hops.IndexingOp in project incubator-systemml by Apache.

From class RewriteAlgebraicSimplificationStatic, method simplifySlicedMatrixMult.

/**
 * Pushes a single-cell right indexing through a matrix multiply so that only one row of the
 * left operand and one column of the right operand are multiplied, e.g.,
 * {@code (X %*% Y)[r, c] -> X[r, ] %*% Y[, c]}.
 *
 * <p>The rewrite applies only when {@code hi} is an {@link IndexingOp} whose row bounds are
 * equal and whose column bounds are equal, whose first input is a matrix multiply, and when
 * that multiply has exactly one consumer. The multiply hop is then rewired in place over two
 * new indexing ops and returned; otherwise {@code hi} is returned unchanged.
 *
 * <p>NOTE(review): {@code parent} and {@code pos} are not used, so {@code parent} is not
 * rewired here and the original indexing hop keeps the multiply as its input; confirm that
 * the caller or a later rewrite accounts for this.
 *
 * @param parent consumer of {@code hi} (not used by this rewrite)
 * @param hi     candidate hop to rewrite
 * @param pos    position of {@code hi} among the inputs of {@code parent} (not used)
 * @return the rewired matrix multiply if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos) {
    // e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1]
    if (!(hi instanceof IndexingOp))
        return hi;
    IndexingOp rix = (IndexingOp) hi;
    if (!rix.isRowLowerEqualsUpper() || !rix.isColLowerEqualsUpper())
        return hi;
    Hop matMult = hi.getInput().get(0);
    // the multiply is rewired in place, so the indexing must be its single consumer
    if (matMult.getParent().size() != 1 || !HopRewriteUtils.isMatrixMultiply(matMult))
        return hi;

    Hop left = matMult.getInput().get(0);
    Hop right = matMult.getInput().get(1);
    Hop rowIndex = hi.getInput().get(1); // row lower bound (rl==ru)
    Hop colIndex = hi.getInput().get(3); // column lower bound (cl==cu)

    // detach the multiply from its original operands before re-attaching the slices
    HopRewriteUtils.removeAllChildReferences(matMult);

    // X[r, ]: single row of the left operand
    IndexingOp leftRow = new IndexingOp("tmp1", DataType.MATRIX, ValueType.DOUBLE, left,
        rowIndex, rowIndex, new LiteralOp(1), HopRewriteUtils.createValueHop(left, false), true, false);
    leftRow.setOutputBlocksizes(left.getRowsInBlock(), left.getColsInBlock());
    leftRow.refreshSizeInformation();

    // Y[, c]: single column of the right operand
    IndexingOp rightCol = new IndexingOp("tmp2", DataType.MATRIX, ValueType.DOUBLE, right,
        new LiteralOp(1), HopRewriteUtils.createValueHop(right, true), colIndex, colIndex, false, true);
    rightCol.setOutputBlocksizes(right.getRowsInBlock(), right.getColsInBlock());
    rightCol.refreshSizeInformation();

    // the multiply now consumes the two slices instead of the full operands
    HopRewriteUtils.addChildReference(matMult, leftRow, 0);
    HopRewriteUtils.addChildReference(matMult, rightCol, 1);
    matMult.refreshSizeInformation();

    LOG.debug("Applied simplifySlicedMatrixMult");
    return matMult;
}
Also used: IndexingOp (org.apache.sysml.hops.IndexingOp), Hop (org.apache.sysml.hops.Hop), LiteralOp (org.apache.sysml.hops.LiteralOp)

Example 22 with IndexingOp

Use of org.apache.sysml.hops.IndexingOp in project incubator-systemml by Apache.

From class PlanSelectionFuseCostBased, method rGetComputeCosts.

/**
 * Recursively estimates the compute costs of {@code current} and of all hops reachable through
 * its inputs, storing one entry per hop ID in {@code computeCosts}.
 *
 * <p>Inputs are costed before their consumer. Any hop whose ID is already present in
 * {@code computeCosts} is skipped, so shared sub-DAGs are visited only once. Costs are relative
 * weights: hop classes without a dedicated rule default to 1, and unhandled operations of a
 * handled hop class also default to 1 but log a warning.
 *
 * @param current      hop to cost (its inputs are costed first)
 * @param partition    current fusion partition; only passed through the recursion, not read here
 * @param computeCosts in/out map from hop ID to estimated compute cost; doubles as the memo table
 */
private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts) {
    if (computeCosts.containsKey(current.getHopID()))
        return;
    //recursively process children
    for (Hop c : current.getInput())
        rGetComputeCosts(c, partition, computeCosts);
    //get costs for given hop
    computeCosts.put(current.getHopID(), getHopComputeCosts(current));
}

/** Returns the relative compute cost of a single hop, dispatching on its hop class. */
private static double getHopComputeCosts(Hop current) {
    if (current instanceof UnaryOp)
        return getUnaryOpComputeCosts((UnaryOp) current);
    if (current instanceof BinaryOp)
        return getBinaryOpComputeCosts((BinaryOp) current);
    if (current instanceof TernaryOp)
        return getTernaryOpComputeCosts((TernaryOp) current);
    if (current instanceof ParameterizedBuiltinOp || current instanceof IndexingOp || current instanceof ReorgOp)
        return 1;
    if (current instanceof AggBinaryOp)
        return 2; //matrix vector
    if (current instanceof AggUnaryOp)
        return getAggUnaryOpComputeCosts((AggUnaryOp) current);
    //all other hop types use the default cost
    return 1;
}

/** Returns the relative cost of a unary operation; unknown operations cost 1 and log a warning. */
private static double getUnaryOpComputeCosts(UnaryOp uop) {
    switch (uop.getOp()) {
        case ABS:
        case ROUND:
        case CEIL:
        case FLOOR:
        case SIGN:
        case SELP:
        case NCOL:
        case NROW:
        case PRINT:
        case CAST_AS_BOOLEAN:
        case CAST_AS_DOUBLE:
        case CAST_AS_INT:
        case CAST_AS_MATRIX:
        case CAST_AS_SCALAR:
        case CUMSUM:
        case CUMMIN:
        case CUMMAX:
        case CUMPROD:
            return 1;
        case SPROP:
        case SQRT:
            return 2;
        case EXP:
        case SIN:
            return 18;
        case SIGMOID:
            return 21;
        case COS:
            return 22;
        case LOG:
        case LOG_NZ:
            return 32;
        case ATAN:
            return 40;
        case TAN:
            return 42;
        case ASIN:
            return 93;
        case ACOS:
            return 103;
        default:
            LOG.warn("Cost model not implemented yet for: " + uop.getOp());
            return 1;
    }
}

/** Returns the relative cost of a binary operation; unknown operations cost 1 and log a warning. */
private static double getBinaryOpComputeCosts(BinaryOp bop) {
    switch (bop.getOp()) {
        case MULT:
        case PLUS:
        case MINUS:
        case MIN:
        case MAX:
        case AND:
        case OR:
        case EQUAL:
        case NOTEQUAL:
        case LESS:
        case LESSEQUAL:
        case GREATER:
        case GREATEREQUAL:
        case CBIND:
        case RBIND:
            return 1;
        case MINUS_NZ:
        case MINUS1_MULT:
            return 2;
        case INTDIV:
            return 6;
        case MODULUS:
            return 8;
        case POW:
            //squaring is as cheap as a multiply
            return HopRewriteUtils.isLiteralOfValue(bop.getInput().get(1), 2) ? 1 : 16;
        case DIV:
            return 22;
        case COVARIANCE:
            return 23;
        case LOG:
        case LOG_NZ:
            return 32;
        case CENTRALMOMENT:
            //unweighted cm(X, order): the order is the second input
            return getCentralMomentComputeCosts(bop.getInput().get(1), 0);
        default:
            LOG.warn("Cost model not implemented yet for: " + bop.getOp());
            return 1;
    }
}

/** Returns the relative cost of a ternary operation; unknown operations cost 1 and log a warning. */
private static double getTernaryOpComputeCosts(TernaryOp top) {
    switch (top.getOp()) {
        case PLUS_MULT:
        case MINUS_MULT:
            return 2;
        case CTABLE:
            return 3;
        case COVARIANCE:
            return 23;
        case CENTRALMOMENT:
            //weighted cm(X, W, order): the order is the THIRD input (the second one holds the
            //weights matrix); weighting adds one unit of cost over the unweighted moment
            return getCentralMomentComputeCosts(top.getInput().get(2), 1);
        default:
            LOG.warn("Cost model not implemented yet for: " + top.getOp());
            return 1;
    }
}

/** Returns the relative cost of a unary aggregate; unknown aggregates cost 1 and log a warning. */
private static double getAggUnaryOpComputeCosts(AggUnaryOp aop) {
    switch (aop.getOp()) {
        case SUM:
            return 4;
        case SUM_SQ:
            return 5;
        case MIN:
        case MAX:
            return 1;
        default:
            LOG.warn("Cost model not implemented yet for: " + aop.getOp());
            return 1;
    }
}

/**
 * Returns the relative cost of a central-moment computation.
 *
 * <p>A non-literal order is costed as order 2. Order values outside 0..5 get the default
 * cost of 1 (without {@code extraCosts}).
 *
 * @param orderInput hop holding the moment order/type
 * @param extraCosts cost added to every known order (0 for unweighted, 1 for weighted moments)
 * @return the estimated relative cost
 */
private static double getCentralMomentComputeCosts(Hop orderInput, double extraCosts) {
    int type = (int) (orderInput instanceof LiteralOp ? HopRewriteUtils.getIntValueSafe((LiteralOp) orderInput) : 2);
    switch (type) {
        case 0: //count
            return 1 + extraCosts;
        case 1: //mean
            return 8 + extraCosts;
        case 2: //cm2
            return 16 + extraCosts;
        case 3: //cm3
            return 31 + extraCosts;
        case 4: //cm4
            return 51 + extraCosts;
        case 5: //variance
            return 16 + extraCosts;
        default:
            return 1;
    }
}
Also used: ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp), AggUnaryOp (org.apache.sysml.hops.AggUnaryOp), UnaryOp (org.apache.sysml.hops.UnaryOp), IndexingOp (org.apache.sysml.hops.IndexingOp), AggBinaryOp (org.apache.sysml.hops.AggBinaryOp), Hop (org.apache.sysml.hops.Hop), ReorgOp (org.apache.sysml.hops.ReorgOp), LiteralOp (org.apache.sysml.hops.LiteralOp), BinaryOp (org.apache.sysml.hops.BinaryOp), TernaryOp (org.apache.sysml.hops.TernaryOp)

Aggregations

IndexingOp (org.apache.sysml.hops.IndexingOp): 22, Hop (org.apache.sysml.hops.Hop): 18, LiteralOp (org.apache.sysml.hops.LiteralOp): 17, LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp): 12, AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 10, DataOp (org.apache.sysml.hops.DataOp): 8, UnaryOp (org.apache.sysml.hops.UnaryOp): 8, BinaryOp (org.apache.sysml.hops.BinaryOp): 6, AggBinaryOp (org.apache.sysml.hops.AggBinaryOp): 5, ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp): 4, TernaryOp (org.apache.sysml.hops.TernaryOp): 4, ForStatementBlock (org.apache.sysml.parser.ForStatementBlock): 4, IfStatementBlock (org.apache.sysml.parser.IfStatementBlock): 4, StatementBlock (org.apache.sysml.parser.StatementBlock): 4, WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock): 4, ArrayList (java.util.ArrayList): 3, ReorgOp (org.apache.sysml.hops.ReorgOp): 3, MatrixObject (org.apache.sysml.runtime.controlprogram.caching.MatrixObject): 3, CNode (org.apache.sysml.hops.codegen.cplan.CNode): 2, CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary): 2