Search in sources :

Example 6 with TernaryOp

Use of org.apache.sysml.hops.TernaryOp in the systemml project by Apache.

From the class RewriteAlgebraicSimplificationDynamic, method simplifyTableSeqExpand.

/**
 * Replaces a five-input ternary "table" hop that pairs a 1..N sequence with a
 * vector by a single REXPAND parameterized builtin over that vector.
 *
 * <p>The rewrite only fires when input 2 (the weight) is the literal 1. The
 * sequence may be the left or the right operand:
 * <ul>
 *   <li>pattern a - sequence is input 0, vector is input 1, input 3 is a size
 *       expression of the vector: rexpand with max = input 4, dir = "cols";</li>
 *   <li>pattern b - vector is input 0, sequence is input 1, input 4 is a size
 *       expression of the vector: rexpand with max = input 3, dir = "rows".</li>
 * </ul>
 * In both cases "ignore" is false and "cast" is true.
 *
 * @param parent hop whose child reference at {@code pos} is rewired on a match
 * @param hi     candidate hop to rewrite
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the newly created rexpand hop if a pattern matched, otherwise {@code hi}
 */
private static Hop simplifyTableSeqExpand(Hop parent, Hop hi, int pos) {
    // guard: only a five-input ternary hop with a literal weight of 1 qualifies
    if (!(hi instanceof TernaryOp)
            || hi.getInput().size() != 5
            || !HopRewriteUtils.isLiteralOfValue(hi.getInput().get(2), 1)) {
        return hi;
    }

    final Hop left = hi.getInput().get(0);
    final Hop right = hi.getInput().get(1);

    // pattern a: table(seq(1,nrow(v)), v, nrow(v), m, 1)
    final boolean seqOnLeft = HopRewriteUtils.isBasic1NSequence(left, right, true)
            && HopRewriteUtils.isSizeExpressionOf(hi.getInput().get(3), right, true);
    // pattern b: table(v, seq(1,nrow(v)), m, nrow(v)); probed only if pattern a failed
    final boolean seqOnRight = !seqOnLeft
            && HopRewriteUtils.isBasic1NSequence(right, left, true)
            && HopRewriteUtils.isSizeExpressionOf(hi.getInput().get(4), left, true);
    if (!seqOnLeft && !seqOnRight) {
        return hi;
    }

    // the operand that is not the sequence becomes the rexpand target
    final Hop target = seqOnLeft ? right : left;
    final Hop max = hi.getInput().get(seqOnLeft ? 4 : 3);

    // setup input parameter hops (literal hops are created in the original order)
    HashMap<String, Hop> params = new HashMap<>();
    params.put("target", target);
    params.put("max", max);
    params.put("dir", new LiteralOp(seqOnLeft ? "cols" : "rows"));
    params.put("ignore", new LiteralOp(false));
    params.put("cast", new LiteralOp(true));

    // create the replacement hop and splice it into the parent
    ParameterizedBuiltinOp rexpand = HopRewriteUtils.createParameterizedBuiltinOp(target, params, ParamBuiltinOp.REXPAND);
    HopRewriteUtils.replaceChildReference(parent, hi, rexpand, pos);
    HopRewriteUtils.cleanupUnreferenced(hi);

    if (seqOnLeft) {
        LOG.debug("Applied simplifyTableSeqExpand1 (line " + rexpand.getBeginLine() + ")");
    } else {
        LOG.debug("Applied simplifyTableSeqExpand2 (line " + rexpand.getBeginLine() + ")");
    }
    return rexpand;
}
Also used : ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) HashMap(java.util.HashMap) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) TernaryOp(org.apache.sysml.hops.TernaryOp)

Example 7 with TernaryOp

Use of org.apache.sysml.hops.TernaryOp in the systemml project by Apache.

From the class PlanSelectionFuseCostBased, method rGetComputeCosts.

/**
 * Recursively records a compute-cost estimate for {@code current} and for every
 * hop reachable through its inputs. Results are stored in {@code computeCosts}
 * under the hop ID; a hop that already has an entry is skipped, so the map also
 * serves as the memo table of the traversal.
 *
 * <p>Hop types or operators without a dedicated entry keep the fallback cost
 * of 1; unknown operators of the switched hop types additionally log a warning.
 *
 * @param current      hop to cost (its inputs are costed first)
 * @param partition    not read by this method; only passed on to the recursive calls
 * @param computeCosts output map from hop ID to cost estimate
 */
private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts) {
    // memoization: skip hops that were costed before
    if (computeCosts.containsKey(current.getHopID())) {
        return;
    }

    // cost all inputs before the hop itself
    for (Hop input : current.getInput()) {
        rGetComputeCosts(input, partition, computeCosts);
    }

    // fallback estimate for anything without a dedicated entry below
    double cost = 1;

    if (current instanceof UnaryOp) {
        final UnaryOp unary = (UnaryOp) current;
        switch (unary.getOp()) {
            case ABS:
            case ROUND:
            case CEIL:
            case FLOOR:
            case SIGN:
            case NCOL:
            case NROW:
            case PRINT:
            case ASSERT:
            case CAST_AS_BOOLEAN:
            case CAST_AS_DOUBLE:
            case CAST_AS_INT:
            case CAST_AS_MATRIX:
            case CAST_AS_SCALAR:
            case CUMSUM:
            case CUMMIN:
            case CUMMAX:
            case CUMPROD:
                cost = 1;
                break;
            case SPROP:
            case SQRT:
                cost = 2;
                break;
            case EXP:
            case SIN:
                cost = 18;
                break;
            case SIGMOID:
                cost = 21;
                break;
            case COS:
                cost = 22;
                break;
            case LOG:
            case LOG_NZ:
                cost = 32;
                break;
            case TAN:
                cost = 42;
                break;
            // TODO (carried over): SINH/COSH/TANH share the ASIN/ACOS/ATAN values,
            // presumably as placeholders - confirm
            case ATAN:
            case TANH:
                cost = 40;
                break;
            case ASIN:
            case SINH:
                cost = 93;
                break;
            case ACOS:
            case COSH:
                cost = 103;
                break;
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + unary.getOp());
        }
    } else if (current instanceof BinaryOp) {
        final BinaryOp binary = (BinaryOp) current;
        switch (binary.getOp()) {
            case MULT:
            case PLUS:
            case MINUS:
            case MIN:
            case MAX:
            case AND:
            case OR:
            case EQUAL:
            case NOTEQUAL:
            case LESS:
            case LESSEQUAL:
            case GREATER:
            case GREATEREQUAL:
            case CBIND:
            case RBIND:
                cost = 1;
                break;
            case MINUS_NZ:
            case MINUS1_MULT:
                cost = 2;
                break;
            case INTDIV:
                cost = 6;
                break;
            case MODULUS:
                cost = 8;
                break;
            case DIV:
                cost = 22;
                break;
            case COVARIANCE:
                cost = 23;
                break;
            case LOG:
            case LOG_NZ:
                cost = 32;
                break;
            case POW:
                // squaring (literal exponent 2) is costed like a multiplication
                if (HopRewriteUtils.isLiteralOfValue(current.getInput().get(1), 2)) {
                    cost = 1;
                } else {
                    cost = 16;
                }
                break;
            case CENTRALMOMENT: {
                // a non-literal second input is treated as type 2
                final Hop typeHop = current.getInput().get(1);
                int cmType = 2;
                if (typeHop instanceof LiteralOp) {
                    cmType = (int) HopRewriteUtils.getIntValueSafe((LiteralOp) typeHop);
                }
                // indexed by type: count, mean, cm2, cm3, cm4, variance
                final double[] cmCosts = { 1, 8, 16, 31, 51, 16 };
                if (cmType >= 0 && cmType < cmCosts.length) {
                    cost = cmCosts[cmType];
                }
                break;
            }
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + binary.getOp());
        }
    } else if (current instanceof TernaryOp) {
        final TernaryOp ternary = (TernaryOp) current;
        switch (ternary.getOp()) {
            case PLUS_MULT:
            case MINUS_MULT:
                cost = 2;
                break;
            case CTABLE:
                cost = 3;
                break;
            case COVARIANCE:
                cost = 23;
                break;
            case CENTRALMOMENT: {
                // a non-literal second input is treated as type 2
                final Hop typeHop = current.getInput().get(1);
                int cmType = 2;
                if (typeHop instanceof LiteralOp) {
                    cmType = (int) HopRewriteUtils.getIntValueSafe((LiteralOp) typeHop);
                }
                // indexed by type: count, mean, cm2, cm3, cm4, variance
                final double[] cmCosts = { 2, 9, 17, 32, 52, 17 };
                if (cmType >= 0 && cmType < cmCosts.length) {
                    cost = cmCosts[cmType];
                }
                break;
            }
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + ternary.getOp());
        }
    } else if (current instanceof ParameterizedBuiltinOp
            || current instanceof IndexingOp
            || current instanceof ReorgOp) {
        cost = 1;
    } else if (current instanceof AggBinaryOp) {
        // matrix vector
        cost = 2;
    } else if (current instanceof AggUnaryOp) {
        final AggUnaryOp agg = (AggUnaryOp) current;
        switch (agg.getOp()) {
            case MIN:
            case MAX:
                cost = 1;
                break;
            case SUM:
                cost = 4;
                break;
            case SUM_SQ:
                cost = 5;
                break;
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + agg.getOp());
        }
    }

    computeCosts.put(current.getHopID(), cost);
}
Also used : ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) IndexingOp(org.apache.sysml.hops.IndexingOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp) TernaryOp(org.apache.sysml.hops.TernaryOp)

Example 8 with TernaryOp

Use of org.apache.sysml.hops.TernaryOp in the incubator-systemml project by Apache.

From the class PlanSelectionFuseCostBased, method rGetComputeCosts.

/**
 * Recursively records a compute-cost estimate for {@code current} and for every
 * hop reachable through its inputs. Results are stored in {@code computeCosts}
 * under the hop ID; a hop that already has an entry is skipped, so the map also
 * serves as the memo table of the traversal.
 *
 * <p>Hop types or operators without a dedicated entry keep the fallback cost
 * of 1; unknown operators of the switched hop types additionally log a warning.
 *
 * @param current      hop to cost (its inputs are costed first)
 * @param partition    not read by this method; only passed on to the recursive calls
 * @param computeCosts output map from hop ID to cost estimate
 */
private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts) {
    //memoization: skip hops that were costed before
    if (computeCosts.containsKey(current.getHopID())) {
        return;
    }

    //cost all inputs before the hop itself
    for (Hop input : current.getInput()) {
        rGetComputeCosts(input, partition, computeCosts);
    }

    //fallback estimate for anything without a dedicated entry below
    double cost = 1;

    if (current instanceof UnaryOp) {
        final UnaryOp unary = (UnaryOp) current;
        switch (unary.getOp()) {
            case ABS:
            case ROUND:
            case CEIL:
            case FLOOR:
            case SIGN:
            case SELP:
            case NCOL:
            case NROW:
            case PRINT:
            case CAST_AS_BOOLEAN:
            case CAST_AS_DOUBLE:
            case CAST_AS_INT:
            case CAST_AS_MATRIX:
            case CAST_AS_SCALAR:
            case CUMSUM:
            case CUMMIN:
            case CUMMAX:
            case CUMPROD:
                cost = 1;
                break;
            case SPROP:
            case SQRT:
                cost = 2;
                break;
            case EXP:
            case SIN:
                cost = 18;
                break;
            case SIGMOID:
                cost = 21;
                break;
            case COS:
                cost = 22;
                break;
            case LOG:
            case LOG_NZ:
                cost = 32;
                break;
            case ATAN:
                cost = 40;
                break;
            case TAN:
                cost = 42;
                break;
            case ASIN:
                cost = 93;
                break;
            case ACOS:
                cost = 103;
                break;
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + unary.getOp());
        }
    } else if (current instanceof BinaryOp) {
        final BinaryOp binary = (BinaryOp) current;
        switch (binary.getOp()) {
            case MULT:
            case PLUS:
            case MINUS:
            case MIN:
            case MAX:
            case AND:
            case OR:
            case EQUAL:
            case NOTEQUAL:
            case LESS:
            case LESSEQUAL:
            case GREATER:
            case GREATEREQUAL:
            case CBIND:
            case RBIND:
                cost = 1;
                break;
            case MINUS_NZ:
            case MINUS1_MULT:
                cost = 2;
                break;
            case INTDIV:
                cost = 6;
                break;
            case MODULUS:
                cost = 8;
                break;
            case DIV:
                cost = 22;
                break;
            case COVARIANCE:
                cost = 23;
                break;
            case LOG:
            case LOG_NZ:
                cost = 32;
                break;
            case POW:
                //squaring (literal exponent 2) is costed like a multiplication
                if (HopRewriteUtils.isLiteralOfValue(current.getInput().get(1), 2)) {
                    cost = 1;
                } else {
                    cost = 16;
                }
                break;
            case CENTRALMOMENT: {
                //a non-literal second input is treated as type 2
                final Hop typeHop = current.getInput().get(1);
                int cmType = 2;
                if (typeHop instanceof LiteralOp) {
                    cmType = (int) HopRewriteUtils.getIntValueSafe((LiteralOp) typeHop);
                }
                //indexed by type: count, mean, cm2, cm3, cm4, variance
                final double[] cmCosts = { 1, 8, 16, 31, 51, 16 };
                if (cmType >= 0 && cmType < cmCosts.length) {
                    cost = cmCosts[cmType];
                }
                break;
            }
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + binary.getOp());
        }
    } else if (current instanceof TernaryOp) {
        final TernaryOp ternary = (TernaryOp) current;
        switch (ternary.getOp()) {
            case PLUS_MULT:
            case MINUS_MULT:
                cost = 2;
                break;
            case CTABLE:
                cost = 3;
                break;
            case COVARIANCE:
                cost = 23;
                break;
            case CENTRALMOMENT: {
                //a non-literal second input is treated as type 2
                final Hop typeHop = current.getInput().get(1);
                int cmType = 2;
                if (typeHop instanceof LiteralOp) {
                    cmType = (int) HopRewriteUtils.getIntValueSafe((LiteralOp) typeHop);
                }
                //indexed by type: count, mean, cm2, cm3, cm4, variance
                final double[] cmCosts = { 2, 9, 17, 32, 52, 17 };
                if (cmType >= 0 && cmType < cmCosts.length) {
                    cost = cmCosts[cmType];
                }
                break;
            }
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + ternary.getOp());
        }
    } else if (current instanceof ParameterizedBuiltinOp
            || current instanceof IndexingOp
            || current instanceof ReorgOp) {
        cost = 1;
    } else if (current instanceof AggBinaryOp) {
        //matrix vector
        cost = 2;
    } else if (current instanceof AggUnaryOp) {
        final AggUnaryOp agg = (AggUnaryOp) current;
        switch (agg.getOp()) {
            case MIN:
            case MAX:
                cost = 1;
                break;
            case SUM:
                cost = 4;
                break;
            case SUM_SQ:
                cost = 5;
                break;
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + agg.getOp());
        }
    }

    computeCosts.put(current.getHopID(), cost);
}
Also used : ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) IndexingOp(org.apache.sysml.hops.IndexingOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp) TernaryOp(org.apache.sysml.hops.TernaryOp)

Example 9 with TernaryOp

Use of org.apache.sysml.hops.TernaryOp in the incubator-systemml project by Apache.

From the class PlanSelectionFuseCostBased, method rGetComputeCosts.

/**
 * Recursively records a compute-cost estimate for {@code current} and for every
 * hop reachable through its inputs. Results are stored in {@code computeCosts}
 * under the hop ID; a hop that already has an entry is skipped, so the map also
 * serves as the memo table of the traversal.
 *
 * <p>Hop types or operators without a dedicated entry keep the fallback cost
 * of 1; unknown operators of the switched hop types additionally log a warning.
 *
 * @param current      hop to cost (its inputs are costed first)
 * @param partition    not read by this method; only passed on to the recursive calls
 * @param computeCosts output map from hop ID to cost estimate
 */
private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts) {
    // memoization: skip hops that were costed before
    if (computeCosts.containsKey(current.getHopID())) {
        return;
    }

    // cost all inputs before the hop itself
    for (Hop input : current.getInput()) {
        rGetComputeCosts(input, partition, computeCosts);
    }

    // fallback estimate for anything without a dedicated entry below
    double cost = 1;

    if (current instanceof UnaryOp) {
        final UnaryOp unary = (UnaryOp) current;
        switch (unary.getOp()) {
            case ABS:
            case ROUND:
            case CEIL:
            case FLOOR:
            case SIGN:
            case NCOL:
            case NROW:
            case PRINT:
            case ASSERT:
            case CAST_AS_BOOLEAN:
            case CAST_AS_DOUBLE:
            case CAST_AS_INT:
            case CAST_AS_MATRIX:
            case CAST_AS_SCALAR:
            case CUMSUM:
            case CUMMIN:
            case CUMMAX:
            case CUMPROD:
                cost = 1;
                break;
            case SPROP:
            case SQRT:
                cost = 2;
                break;
            case EXP:
            case SIN:
                cost = 18;
                break;
            case SIGMOID:
                cost = 21;
                break;
            case COS:
                cost = 22;
                break;
            case LOG:
            case LOG_NZ:
                cost = 32;
                break;
            case TAN:
                cost = 42;
                break;
            // TODO (carried over): SINH/COSH/TANH share the ASIN/ACOS/ATAN values,
            // presumably as placeholders - confirm
            case ATAN:
            case TANH:
                cost = 40;
                break;
            case ASIN:
            case SINH:
                cost = 93;
                break;
            case ACOS:
            case COSH:
                cost = 103;
                break;
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + unary.getOp());
        }
    } else if (current instanceof BinaryOp) {
        final BinaryOp binary = (BinaryOp) current;
        switch (binary.getOp()) {
            case MULT:
            case PLUS:
            case MINUS:
            case MIN:
            case MAX:
            case AND:
            case OR:
            case EQUAL:
            case NOTEQUAL:
            case LESS:
            case LESSEQUAL:
            case GREATER:
            case GREATEREQUAL:
            case CBIND:
            case RBIND:
                cost = 1;
                break;
            case MINUS_NZ:
            case MINUS1_MULT:
                cost = 2;
                break;
            case INTDIV:
                cost = 6;
                break;
            case MODULUS:
                cost = 8;
                break;
            case DIV:
                cost = 22;
                break;
            case COVARIANCE:
                cost = 23;
                break;
            case LOG:
            case LOG_NZ:
                cost = 32;
                break;
            case POW:
                // squaring (literal exponent 2) is costed like a multiplication
                if (HopRewriteUtils.isLiteralOfValue(current.getInput().get(1), 2)) {
                    cost = 1;
                } else {
                    cost = 16;
                }
                break;
            case CENTRALMOMENT: {
                // a non-literal second input is treated as type 2
                final Hop typeHop = current.getInput().get(1);
                int cmType = 2;
                if (typeHop instanceof LiteralOp) {
                    cmType = (int) HopRewriteUtils.getIntValueSafe((LiteralOp) typeHop);
                }
                // indexed by type: count, mean, cm2, cm3, cm4, variance
                final double[] cmCosts = { 1, 8, 16, 31, 51, 16 };
                if (cmType >= 0 && cmType < cmCosts.length) {
                    cost = cmCosts[cmType];
                }
                break;
            }
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + binary.getOp());
        }
    } else if (current instanceof TernaryOp) {
        final TernaryOp ternary = (TernaryOp) current;
        switch (ternary.getOp()) {
            case PLUS_MULT:
            case MINUS_MULT:
                cost = 2;
                break;
            case CTABLE:
                cost = 3;
                break;
            case COVARIANCE:
                cost = 23;
                break;
            case CENTRALMOMENT: {
                // a non-literal second input is treated as type 2
                final Hop typeHop = current.getInput().get(1);
                int cmType = 2;
                if (typeHop instanceof LiteralOp) {
                    cmType = (int) HopRewriteUtils.getIntValueSafe((LiteralOp) typeHop);
                }
                // indexed by type: count, mean, cm2, cm3, cm4, variance
                final double[] cmCosts = { 2, 9, 17, 32, 52, 17 };
                if (cmType >= 0 && cmType < cmCosts.length) {
                    cost = cmCosts[cmType];
                }
                break;
            }
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + ternary.getOp());
        }
    } else if (current instanceof ParameterizedBuiltinOp
            || current instanceof IndexingOp
            || current instanceof ReorgOp) {
        cost = 1;
    } else if (current instanceof AggBinaryOp) {
        // matrix vector
        cost = 2;
    } else if (current instanceof AggUnaryOp) {
        final AggUnaryOp agg = (AggUnaryOp) current;
        switch (agg.getOp()) {
            case MIN:
            case MAX:
                cost = 1;
                break;
            case SUM:
                cost = 4;
                break;
            case SUM_SQ:
                cost = 5;
                break;
            default:
                LOG.warn("Cost model not " + "implemented yet for: " + agg.getOp());
        }
    }

    computeCosts.put(current.getHopID(), cost);
}
Also used : ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) IndexingOp(org.apache.sysml.hops.IndexingOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp) TernaryOp(org.apache.sysml.hops.TernaryOp)

Example 10 with TernaryOp

Use of org.apache.sysml.hops.TernaryOp in the incubator-systemml project by Apache.

From the class TemplateRow, method rConstructCplan.

/**
 * Recursively constructs the cnode for {@code hop} and registers it in {@code tmp}.
 *
 * Inputs that the memo table entry marks as plan references are processed
 * recursively first; every other input is turned into a {@code CNodeData} leaf,
 * stored in {@code tmp}, and added to {@code inHops}. Afterwards a cnode is
 * built for {@code hop} itself, depending on its operator class, and stored in
 * {@code tmp} under the hop ID.
 *
 * @param hop the hop for which a cnode is constructed
 * @param memo memo table queried for the best ROW/CELL entry of {@code hop}
 * @param tmp map from hop ID to constructed cnode; also serves as memoization (modified)
 * @param inHops collected input hops (modified: hops are added and removed)
 * @param inHops2 named inputs, keyed by "X" and "B1", set by specific operator patterns (modified)
 * @param compileLiterals passed through to {@code TemplateUtils.createCNodeData} and {@code TemplateUtils.skipTranspose}
 * @throws RuntimeException if a matrix unary/binary operation is not supported,
 *         or if no branch produced a cnode for {@code hop}
 */
private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
    // memoization for common subexpression elimination and to avoid redundant work
    if (tmp.containsKey(hop.getHopID()))
        return;
    // recursively process required children; inputs that are not plan
    // references become data leaves and are recorded as plan inputs
    MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.ROW, TemplateType.CELL);
    for (int i = 0; i < hop.getInput().size(); i++) {
        Hop c = hop.getInput().get(i);
        if (me != null && me.isPlanRef(i))
            rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
        else {
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
    }
    // construct cnode for current hop
    CNode out = null;
    if (hop instanceof AggUnaryOp) {
        // unary aggregate: handled per direction (Row, Col, RowCol); any other
        // combination leaves out == null and fails at the end of the method
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        if (((AggUnaryOp) hop).getDirection() == Direction.Row && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG)) {
            // input with a single column: reuse a scalar cnode as is, otherwise wrap it in LOOKUP_R
            if (hop.getInput().get(0).getDim2() == 1)
                out = (cdata1.getDataType() == DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            else {
                // unary type name is derived from the aggregate op name: "ROW_" + name + "S"
                String opcode = "ROW_" + ((AggUnaryOp) hop).getOp().name().toUpperCase() + "S";
                out = new CNodeUnary(cdata1, UnaryType.valueOf(opcode));
                // register the data input as "X" unless "X" is already set
                if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
                    inHops2.put("X", hop.getInput().get(0));
            }
        } else if (((AggUnaryOp) hop).getDirection() == Direction.Col && ((AggUnaryOp) hop).getOp() == AggOp.SUM) {
            // vector add without temporary copy
            if (cdata1 instanceof CNodeBinary && ((CNodeBinary) cdata1).getType().isVectorScalarPrimitive())
                out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1), ((CNodeBinary) cdata1).getType().getVectorAddPrimitive());
            else
                out = cdata1;
        } else if (((AggUnaryOp) hop).getDirection() == Direction.RowCol && ((AggUnaryOp) hop).getOp() == AggOp.SUM) {
            // full sum: matrix-typed input is reduced via ROW_SUMS, anything else is passed through
            out = (cdata1.getDataType().isMatrix()) ? new CNodeUnary(cdata1, UnaryType.ROW_SUMS) : cdata1;
        }
    } else if (hop instanceof AggBinaryOp) {
        // binary aggregate: the primitive is chosen by whether the first input
        // is a transpose and by the column counts of the inputs
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        if (HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))) {
            // correct input under transpose
            cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
            // the transpose hop itself is no longer a plan input; if the result is
            // a data node, the hop below the transpose is registered instead
            inHops.remove(hop.getInput().get(0));
            if (cdata1 instanceof CNodeData)
                inHops.add(hop.getInput().get(0).getInput().get(0));
            // note: vectorMultAdd applicable to vector-scalar, and vector-vector
            if (hop.getInput().get(1).getDim2() == 1)
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MULT_ADD);
            else {
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_OUTERMULT_ADD);
                if (!inHops2.containsKey("B1")) {
                    // incl modification of X for consistency
                    if (cdata1 instanceof CNodeData)
                        inHops2.put("X", hop.getInput().get(0).getInput().get(0));
                    inHops2.put("B1", hop.getInput().get(1));
                }
            }
            // default "X" to the hop below the transpose if nothing set it so far
            if (!inHops2.containsKey("X"))
                inHops2.put("X", hop.getInput().get(0).getInput().get(0));
        } else {
            // both inputs have a single column: scalar multiply, with LOOKUP0 on non-scalar cnodes
            if (hop.getInput().get(0).getDim2() == 1 && hop.getInput().get(1).getDim2() == 1)
                out = new CNodeBinary((cdata1.getDataType() == DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1, UnaryType.LOOKUP0), (cdata2.getDataType() == DataType.SCALAR) ? cdata2 : new CNodeUnary(cdata2, UnaryType.LOOKUP0), BinType.MULT);
            else if (hop.getInput().get(1).getDim2() == 1) {
                // only the second input has a single column; "X" is overwritten unconditionally here
                out = new CNodeBinary(cdata1, cdata2, BinType.DOT_PRODUCT);
                inHops2.put("X", hop.getInput().get(0));
            } else {
                // second input has more than one column (or unknown); "X" and "B1" are overwritten unconditionally
                out = new CNodeBinary(cdata1, cdata2, BinType.VECT_MATRIXMULT);
                inHops2.put("X", hop.getInput().get(0));
                inHops2.put("B1", hop.getInput().get(1));
            }
        }
    } else if (HopRewriteUtils.isTransposeOperation(hop)) {
        // NOTE(review): because of the early return at the top, tmp appears to hold
        // no entry for hop at this point, so the first argument is presumably null;
        // confirm that TemplateUtils.skipTranspose handles that
        out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), hop, tmp, compileLiterals);
        if (out instanceof CNodeData && !inHops.contains(hop.getInput().get(0)))
            inHops.add(hop.getInput().get(0));
    } else if (hop instanceof UnaryOp) {
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        // if one input is a matrix then we need to do vector by scalar operations
        if (hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1 || (!hop.dimsKnown() && cdata1.getDataType() == DataType.MATRIX)) {
            if (HopRewriteUtils.isUnary(hop, SUPPORTED_VECT_UNARY)) {
                // vector primitive name is "VECT_" + op name
                String opname = "VECT_" + ((UnaryOp) hop).getOp().name();
                out = new CNodeUnary(cdata1, UnaryType.valueOf(opname));
                if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
                    inHops2.put("X", hop.getInput().get(0));
            } else
                throw new RuntimeException("Unsupported unary matrix " + "operation: " + ((UnaryOp) hop).getOp().name());
        } else // general scalar case
        {
            cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
            // scalar primitive name is taken from the op's toString()
            String primitiveOpName = ((UnaryOp) hop).getOp().toString();
            out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
        }
    } else if (HopRewriteUtils.isBinary(hop, OpOp2.CBIND)) {
        // special case for cbind with zeros
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = null;
        if (HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1))) {
            // constant data generator: use its constant value directly as a data node
            cdata2 = TemplateUtils.createCNodeData(HopRewriteUtils.getDataGenOpConstantValue(hop.getInput().get(1)), true);
            // rm 0-matrix
            inHops.remove(hop.getInput().get(1));
        } else {
            cdata2 = tmp.get(hop.getInput().get(1).getHopID());
            cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
        }
        out = new CNodeBinary(cdata1, cdata2, BinType.VECT_CBIND);
        if (cdata1 instanceof CNodeData && !inHops2.containsKey("X"))
            inHops2.put("X", hop.getInput().get(0));
    } else if (hop instanceof BinaryOp) {
        // note: binary cbind was handled by the preceding branch
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        // if one input is a matrix then we need to do vector by scalar operations
        if ((hop.getInput().get(0).getDim1() >= 1 && hop.getInput().get(0).getDim2() > 1) || (hop.getInput().get(1).getDim1() >= 1 && hop.getInput().get(1).getDim2() > 1) || (!(hop.dimsKnown() && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()) && // not a known vector output
        (hop.getDim2() != 1) && (cdata1.getDataType().isMatrix() || cdata2.getDataType().isMatrix()))) {
            if (HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY)) {
                // matrix with matrix/row-vector: "VECT_" + op name; otherwise the
                // "_SCALAR" variant with column vectors wrapped in LOOKUP_R
                if (TemplateUtils.isMatrix(cdata1) && (TemplateUtils.isMatrix(cdata2) || TemplateUtils.isRowVector(cdata2))) {
                    String opname = "VECT_" + ((BinaryOp) hop).getOp().name();
                    out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
                } else {
                    String opname = "VECT_" + ((BinaryOp) hop).getOp().name() + "_SCALAR";
                    if (TemplateUtils.isColVector(cdata1))
                        cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
                    if (TemplateUtils.isColVector(cdata2))
                        cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
                    out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
                }
                // register a non-scalar data input as "X" unless "X" is already set
                if (cdata1 instanceof CNodeData && !inHops2.containsKey("X") && !(cdata1.getDataType() == DataType.SCALAR)) {
                    inHops2.put("X", hop.getInput().get(0));
                }
            } else
                throw new RuntimeException("Unsupported binary matrix " + "operation: " + ((BinaryOp) hop).getOp().name());
        } else // one input is a vector/scalar other is a scalar
        {
            String primitiveOpName = ((BinaryOp) hop).getOp().toString();
            if (TemplateUtils.isColVector(cdata1))
                cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
            if (// vector or vector can be inferred from lhs
            TemplateUtils.isColVector(cdata2) || (TemplateUtils.isColVector(hop.getInput().get(0)) && cdata2 instanceof CNodeData && hop.getInput().get(1).getDataType().isMatrix()))
                cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
            out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
        }
    } else if (hop instanceof TernaryOp) {
        TernaryOp top = (TernaryOp) hop;
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
        CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID());
        // add lookups if required (only applied to the first and third input here)
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
        // construct ternary cnode, primitive operation derived from OpOp3
        out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().toString()));
    } else if (HopRewriteUtils.isNary(hop, OpOpN.CBIND)) {
        // n-ary cbind: collect one cnode per input, in input order
        CNode[] inputs = new CNode[hop.getInput().size()];
        for (int i = 0; i < hop.getInput().size(); i++) {
            Hop c = hop.getInput().get(i);
            CNode cdata = tmp.get(c.getHopID());
            if (TemplateUtils.isColVector(cdata) || TemplateUtils.isRowVector(cdata))
                cdata = TemplateUtils.wrapLookupIfNecessary(cdata, c);
            inputs[i] = cdata;
            // only the first input is a candidate for "X"
            if (i == 0 && cdata instanceof CNodeData && !inHops2.containsKey("X"))
                inHops2.put("X", c);
        }
        out = new CNodeNary(inputs, NaryType.VECT_CBIND);
    } else if (hop instanceof ParameterizedBuiltinOp) {
        // reads the "pattern" and "replacement" parameters of the builtin and maps
        // it to a ternary REPLACE; a literal pattern named "Double.NaN" selects REPLACE_NAN
        // NOTE(review): presumably only replace builtins reach this branch - confirm,
        // since the parameter hops are dereferenced without a null check
        CNode cdata1 = tmp.get(((ParameterizedBuiltinOp) hop).getTargetHop().getHopID());
        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
        CNode cdata2 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("pattern").getHopID());
        CNode cdata3 = tmp.get(((ParameterizedBuiltinOp) hop).getParameterHop("replacement").getHopID());
        TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE;
        out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
    } else if (hop instanceof IndexingOp) {
        // lookup built from the indexed input, a literal holding that input's
        // column count, and the hop's input at position 4; the lookup type
        // depends on whether the output column count equals 1
        CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
        out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), (hop.getDim2() != 1) ? TernaryType.LOOKUP_RVECT1 : TernaryType.LOOKUP_RC1);
    }
    // no branch above produced a cnode: fail, reporting hop ID and operator string
    if (out == null) {
        throw new RuntimeException(hop.getHopID() + " " + hop.getOpString());
    }
    // matrix-typed cnodes carry the dimensions of the hop they were built for
    if (out.getDataType().isMatrix()) {
        out.setNumRows(hop.getDim1());
        out.setNumCols(hop.getDim2());
    }
    tmp.put(hop.getHopID(), out);
}
Also used : TernaryType(org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType) CNodeData(org.apache.sysml.hops.codegen.cplan.CNodeData) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) CNodeTernary(org.apache.sysml.hops.codegen.cplan.CNodeTernary) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) CNodeBinary(org.apache.sysml.hops.codegen.cplan.CNodeBinary) CNodeNary(org.apache.sysml.hops.codegen.cplan.CNodeNary) TernaryOp(org.apache.sysml.hops.TernaryOp) CNode(org.apache.sysml.hops.codegen.cplan.CNode) ParameterizedBuiltinOp(org.apache.sysml.hops.ParameterizedBuiltinOp) CNodeUnary(org.apache.sysml.hops.codegen.cplan.CNodeUnary) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) IndexingOp(org.apache.sysml.hops.IndexingOp) MemoTableEntry(org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

TernaryOp (org.apache.sysml.hops.TernaryOp): 19; AggBinaryOp (org.apache.sysml.hops.AggBinaryOp): 15; Hop (org.apache.sysml.hops.Hop): 15; LiteralOp (org.apache.sysml.hops.LiteralOp): 13; ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp): 13; AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 11; BinaryOp (org.apache.sysml.hops.BinaryOp): 11; UnaryOp (org.apache.sysml.hops.UnaryOp): 11; IndexingOp (org.apache.sysml.hops.IndexingOp): 7; ReorgOp (org.apache.sysml.hops.ReorgOp): 7; HashMap (java.util.HashMap): 4; DataOp (org.apache.sysml.hops.DataOp): 4; CNode (org.apache.sysml.hops.codegen.cplan.CNode): 4; CNodeBinary (org.apache.sysml.hops.codegen.cplan.CNodeBinary): 4; CNodeData (org.apache.sysml.hops.codegen.cplan.CNodeData): 4; CNodeTernary (org.apache.sysml.hops.codegen.cplan.CNodeTernary): 4; TernaryType (org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType): 4; CNodeUnary (org.apache.sysml.hops.codegen.cplan.CNodeUnary): 4; MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry): 4; ArrayList (java.util.ArrayList): 2