Search in sources :

Example 6 with LiteralOp

Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.

In class RewriteAlgebraicSimplificationStatic, method removeUnnecessaryMinus.

/**
 * Removes an unnecessary chain of two negations, i.e., rewrites {@code -(-(X))}, which appears
 * in the hop DAG as {@code 0 - (0 - X)}, into {@code X}.
 * <p>
 * Both minus operators must be matrix-typed {@link BinaryOp}s whose left input is the literal 0.
 * On a match, {@code X} is relinked into {@code parent} at position {@code pos} and the two
 * minus hops are passed to {@code HopRewriteUtils.cleanupUnreferenced}.
 *
 * @param parent parent of {@code hi}
 * @param hi candidate outer minus operator
 * @param pos position of {@code hi} within the inputs of {@code parent}
 * @return {@code X} if the rewrite applied, otherwise {@code hi}
 * @throws HopsException propagated from the hop utilities / literal access
 */
private Hop removeUnnecessaryMinus(Hop parent, Hop hi, int pos) throws HopsException {
    //first minus
    if (isZeroMinusMatrixOp(hi)) {
        Hop hi2 = hi.getInput().get(1);
        //second minus
        if (isZeroMinusMatrixOp(hi2)) {
            Hop hi3 = hi2.getInput().get(1);
            //capture the line of the removed outer minus before relinking (for the debug log)
            int line = hi.getBeginLine();
            //remove unnecessary chain of -(-())
            HopRewriteUtils.replaceChildReference(parent, hi, hi3, pos);
            HopRewriteUtils.cleanupUnreferenced(hi, hi2);
            hi = hi3;
            //log message uses the correct rewrite name and the line, like the sibling rewrites
            LOG.debug("Applied removeUnnecessaryMinus (line " + line + ").");
        }
    }
    return hi;
}

/**
 * Checks whether the given hop is a matrix-typed binary minus {@code 0 - Y}, i.e., the hop
 * representation of a negation {@code -(Y)}.
 *
 * @param h hop to test
 * @return true if {@code h} is a matrix {@code 0 - Y} binary operator
 * @throws HopsException propagated from literal access
 */
private static boolean isZeroMinusMatrixOp(Hop h) throws HopsException {
    return h.getDataType() == DataType.MATRIX
        && h instanceof BinaryOp
        && ((BinaryOp) h).getOp() == OpOp2.MINUS
        && h.getInput().get(0) instanceof LiteralOp
        && ((LiteralOp) h.getInput().get(0)).getDoubleValue() == 0;
}
Also used : Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 7 with LiteralOp

Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.

In class RewriteAlgebraicSimplificationStatic, method fuseLogNzUnaryOperation.

/**
 * Fuses the pattern {@code (X != 0) * log(X)} into the single unary operator {@code log_nz(X)}.
 * <p>
 * This is done as a hop rewrite in order to reduce the memory estimate and to prevent dense
 * intermediates if X is ultra sparse. The match requires the NOTEQUAL predicate and the log to
 * consume the very same hop object X. It therefore depends on common subexpression elimination
 * having merged both references beforehand.
 *
 * @param parent parent of {@code hi}
 * @param hi candidate multiplication hop
 * @param pos position of {@code hi} within the inputs of {@code parent}
 * @return the new {@code log_nz} hop if the rewrite applied, otherwise {@code hi}
 * @throws HopsException propagated from the hop utilities
 */
private Hop fuseLogNzUnaryOperation(Hop parent, Hop hi, int pos) throws HopsException {
    //guard 1: matrix-matrix multiply whose right operand is log(...)
    boolean isMatrixMultWithLog = HopRewriteUtils.isBinary(hi, OpOp2.MULT)
        && hi.getInput().get(0).getDataType() == DataType.MATRIX
        && hi.getInput().get(1).getDataType() == DataType.MATRIX
        && HopRewriteUtils.isUnary(hi.getInput().get(1), OpOp1.LOG);
    if (!isMatrixMultWithLog) {
        return hi;
    }

    Hop predicate = hi.getInput().get(0);
    Hop logInput = hi.getInput().get(1).getInput().get(0);

    //guard 2: left operand is (X != 0) over the same hop X that feeds log(X)
    boolean isNonZeroPredicate = HopRewriteUtils.isBinary(predicate, OpOp2.NOTEQUAL)
        && predicate.getInput().get(0) == logInput
        && predicate.getInput().get(1) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValueSafe((LiteralOp) predicate.getInput().get(1)) == 0;
    if (!isNonZeroPredicate) {
        return hi;
    }

    Hop fused = HopRewriteUtils.createUnary(logInput, OpOp1.LOG_NZ);
    //relink the fused hop into the original position of hi
    HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
    LOG.debug("Applied fuseLogNzUnaryOperation (line " + fused.getBeginLine() + ").");
    return fused;
}
Also used : Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 8 with LiteralOp

Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.

In class RewriteAlgebraicSimplificationStatic, method fuseDatagenAndMinusOperation.

/**
 * Fuses the negation of a uniform random matrix into the datagen operator itself, i.e., rewrites
 * {@code 0 - rand(min=a, max=b, pdf="uniform")} into {@code rand(min=-b, max=-a, pdf="uniform")}.
 * <p>
 * For simplicity, the rewrite is only applied when all of the following hold:
 * <ul>
 * <li>the binary operator is a minus with literal-0 left input;</li>
 * <li>the right input is a RAND datagen whose only parent is this minus;</li>
 * <li>min, max and pdf are literals, and the pdf is uniform.</li>
 * </ul>
 * The datagen's min/max inputs are replaced in place. All parents of the minus are then rewired
 * to the datagen.
 *
 * @param hi candidate binary operator
 * @return the modified datagen if the rewrite applied, otherwise {@code hi}
 * @throws HopsException propagated from the hop utilities / literal access
 */
private Hop fuseDatagenAndMinusOperation(Hop hi) throws HopsException {
    //the operator must be a minus: other binary ops with a literal 0 left input
    //(e.g., 0 + rand(...) or 0 * rand(...)) are not negations and must not be rewritten
    if (hi instanceof BinaryOp && ((BinaryOp) hi).getOp() == OpOp2.MINUS) {
        BinaryOp bop = (BinaryOp) hi;
        Hop left = bop.getInput().get(0);
        Hop right = bop.getInput().get(1);
        if (right instanceof DataGenOp && ((DataGenOp) right).getOp() == DataGenMethod.RAND && left instanceof LiteralOp && ((LiteralOp) left).getDoubleValue() == 0.0) {
            DataGenOp inputGen = (DataGenOp) right;
            HashMap<String, Integer> params = inputGen.getParamIndexMap();
            Hop pdf = right.getInput().get(params.get(DataExpression.RAND_PDF));
            int ixMin = params.get(DataExpression.RAND_MIN);
            int ixMax = params.get(DataExpression.RAND_MAX);
            Hop min = right.getInput().get(ixMin);
            Hop max = right.getInput().get(ixMax);
            //apply rewrite under additional conditions (for simplicity)
            if (inputGen.getParent().size() == 1 && min instanceof LiteralOp && max instanceof LiteralOp && pdf instanceof LiteralOp && DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp) pdf).getStringValue())) {
                //exchange and *-1 (special case 0 stays 0 instead of -0 for consistency)
                double newMinVal = (((LiteralOp) max).getDoubleValue() == 0) ? 0 : (-1 * ((LiteralOp) max).getDoubleValue());
                double newMaxVal = (((LiteralOp) min).getDoubleValue() == 0) ? 0 : (-1 * ((LiteralOp) min).getDoubleValue());
                Hop newMin = new LiteralOp(newMinVal);
                Hop newMax = new LiteralOp(newMaxVal);
                HopRewriteUtils.removeChildReferenceByPos(inputGen, min, ixMin);
                HopRewriteUtils.addChildReference(inputGen, newMin, ixMin);
                HopRewriteUtils.removeChildReferenceByPos(inputGen, max, ixMax);
                HopRewriteUtils.addChildReference(inputGen, newMax, ixMax);
                //rewire all parents (avoid anomalies with replicated datagen);
                //iterate over a copy because rewiring modifies bop's parent list
                List<Hop> parents = new ArrayList<Hop>(bop.getParent());
                for (Hop p : parents) {
                    HopRewriteUtils.replaceChildReference(p, bop, inputGen);
                }
                hi = inputGen;
                LOG.debug("Applied fuseDatagenAndMinusOperation (line " + bop.getBeginLine() + ").");
            }
        }
    }
    return hi;
}
Also used : DataGenOp(org.apache.sysml.hops.DataGenOp) Hop(org.apache.sysml.hops.Hop) ArrayList(java.util.ArrayList) LiteralOp(org.apache.sysml.hops.LiteralOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 9 with LiteralOp

Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.

In class RewriteCommonSubexpressionElimination, method rule_CommonSubexpressionElimination_MergeLeafs.

/**
 * Merges duplicate leaf hops (literals and data reads) of the DAG rooted at {@code hop}.
 * <p>
 * The first leaf encountered for a key becomes the canonical hop for that key. Literals are
 * keyed by {@code valueType + "_" + name}, and data reads (leaf {@link DataOp}s with
 * {@code isRead()}) by name. Whenever an inner node references a different hop with an
 * already-registered key, that input slot is redirected to the canonical hop, and the inner
 * node is added to the canonical hop's parents. The traversal is depth-first over the original
 * inputs and marks every processed hop as visited.
 *
 * @param hop current hop
 * @param dataops canonical data-read leaves, keyed by name
 * @param literalops canonical literal leaves, keyed by value type and name
 * @return number of redirected input references in this sub-DAG
 * @throws HopsException propagated from recursive processing
 */
private int rule_CommonSubexpressionElimination_MergeLeafs(Hop hop, HashMap<String, Hop> dataops, HashMap<String, Hop> literalops) throws HopsException {
    if (hop.isVisited()) {
        return 0;
    }
    int merged = 0;
    if (hop.getInput().isEmpty()) {
        //leaf node: the first occurrence of a key becomes the canonical hop
        if (hop instanceof LiteralOp) {
            String leafKey = hop.getValueType() + "_" + hop.getName();
            if (!literalops.containsKey(leafKey)) {
                literalops.put(leafKey, hop);
            }
        } else if (hop instanceof DataOp && ((DataOp) hop).isRead() && !dataops.containsKey(hop.getName())) {
            dataops.put(hop.getName(), hop);
        }
    } else {
        //inner node: redirect inputs to canonical leaves, then recurse
        for (int ix = 0; ix < hop.getInput().size(); ix++) {
            Hop child = hop.getInput().get(ix);
            Hop canonical = null;
            if (child instanceof DataOp && ((DataOp) child).isRead() && dataops.containsKey(child.getName())) {
                canonical = dataops.get(child.getName());
            } else if (child instanceof LiteralOp) {
                String childKey = child.getValueType() + "_" + child.getName();
                if (literalops.containsKey(childKey)) {
                    canonical = literalops.get(childKey);
                }
            }
            //replace the child reference only if it is not already the canonical hop
            if (canonical != null && canonical != child) {
                canonical.getParent().add(hop);
                hop.getInput().set(ix, canonical);
                merged++;
            }
            //recurse into the original child (for a merged-away leaf this just marks it visited)
            merged += rule_CommonSubexpressionElimination_MergeLeafs(child, dataops, literalops);
        }
    }
    hop.setVisited();
    return merged;
}
Also used : Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) DataOp(org.apache.sysml.hops.DataOp)

Example 10 with LiteralOp

Use of org.apache.sysml.hops.LiteralOp in project incubator-systemml by Apache.

In class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedSquaredLoss.

/**
	 * Searches for weighted squared loss expressions and replaces them with a quaternary WSLOSS
	 * operator. Currently, this search includes the following three patterns. Each is also
	 * matched with the operands of the minus swapped, e.g., (U %*% t(V) - X) instead of (X - U %*% t(V)):
	 * 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
	 * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
	 * 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
	 * 
	 * In pattern 1, if W is a non-zero indicator of X (HopRewriteUtils.isNonZeroIndicator),
	 * W is replaced by the literal 1 (post_nz special case).
	 * 
	 * NOTE: We include transpose into the pattern because during runtime we need to compute
	 * U %*% t(V) pointwise; having V and not t(V) at hand allows for a cache-friendly implementation
	 * without additional memory requirements for internal transpose. If the right matrix-multiply
	 * factor is not an explicit transpose, it is wrapped into a newly created transpose hop instead.
	 * 
	 * This rewrite is conceptually a static rewrite; however, the current MR runtime only supports
	 * U/V factors of rank up to the blocksize (1000). We enforce this constraint here during the general
	 * rewrite because this is an uncommon case. Also, the intention is to remove this constraint as soon
	 * as we generalized the runtime or hop/lop compilation. 
	 * 
	 * @param parent parent high-level operator
	 * @param hi high-level operator (candidate sum aggregate)
	 * @param pos position of hi within the inputs of parent
	 * @return the new quaternary operator (already relinked into parent at pos) if a pattern
	 *         matched, otherwise hi
	 * @throws HopsException if HopsException occurs
	 */
private Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos) throws HopsException {
    //NOTE: there might be also a general simplification without custom operator
    //via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
    Hop hnew = null;
    //match conditions: hi is a full (RowCol) sum() aggregate, its input is a binary op
    //(all patterns are subrooted by a binary op), and that input has more than one column
    //(the rewrite is not applied for vector-vector mult)
    if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getDirection() == Direction.RowCol && //all patterns rooted by sum()
    ((AggUnaryOp) hi).getOp() == AggOp.SUM && //all patterns subrooted by binary op
    hi.getInput().get(0) instanceof BinaryOp && //not applied for vector-vector mult
    hi.getInput().get(0).getDim2() > 1) {
        BinaryOp bop = (BinaryOp) hi.getInput().get(0);
        //ensures at most one pattern is applied (checked by patterns 2 and 3)
        boolean appliedPattern = false;
        //pattern 1 (post weighting): sum (W * (X - U %*% t(V)) ^ 2),
        //incl. alternative sum (W * (U %*% t(V) - X) ^ 2)
        if (bop.getOp() == OpOp2.MULT && HopRewriteUtils.isBinary(bop.getInput().get(1), OpOp2.POW) && bop.getInput().get(0).getDataType() == DataType.MATRIX && //prevent mv
        HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) && bop.getInput().get(1).getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1).getInput().get(1)) == 2) {
            Hop W = bop.getInput().get(0);
            //(X - U %*% t(V))
            Hop tmp = bop.getInput().get(1).getInput().get(0);
            if (HopRewriteUtils.isBinary(tmp, OpOp2.MINUS) && //prevent mv
            HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) && tmp.getInput().get(0).getDataType() == DataType.MATRIX) {
                //a) sum (W * (X - U %*% t(V)) ^ 2)
                //uvIndex: input position of the U %*% t(V) matrix multiply within the minus (-1 if none)
                int uvIndex = -1;
                if (//ba guarantees matrices
                tmp.getInput().get(1) instanceof AggBinaryOp && //BLOCKSIZE CONSTRAINT
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) {
                    uvIndex = 1;
                } else //b) sum (W * (U %*% t(V) - X) ^ 2)
                if (//ba guarantees matrices
                tmp.getInput().get(0) instanceof AggBinaryOp && //BLOCKSIZE CONSTRAINT
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0), true)) {
                    uvIndex = 0;
                }
                if (//rewrite match
                uvIndex >= 0) {
                    //X is the other operand of the minus
                    Hop X = tmp.getInput().get((uvIndex == 0) ? 1 : 0);
                    Hop U = tmp.getInput().get(uvIndex).getInput().get(0);
                    Hop V = tmp.getInput().get(uvIndex).getInput().get(1);
                    //pass V rather than t(V): unwrap an explicit transpose, otherwise create one
                    if (!HopRewriteUtils.isTransposeOperation(V)) {
                        V = HopRewriteUtils.createTranspose(V);
                    } else {
                        V = V.getInput().get(0);
                    }
                    //handle special case of post_nz (W is a non-zero indicator of X)
                    if (HopRewriteUtils.isNonZeroIndicator(W, X)) {
                        W = new LiteralOp(1);
                    }
                    //construct quaternary hop (last argument presumably the post-weighting flag;
                    //true only for this pattern - TODO confirm against QuaternaryOp)
                    hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, true);
                    HopRewriteUtils.setOutputParametersForScalar(hnew);
                    appliedPattern = true;
                    LOG.debug("Applied simplifyWeightedSquaredLoss1" + uvIndex + " (line " + hi.getBeginLine() + ")");
                }
            }
        }
        //pattern 2 (pre weighting): sum ((X - W * (U %*% t(V))) ^ 2),
        //incl. alternative sum ((W * (U %*% t(V)) - X) ^ 2)
        if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) && bop.getInput().get(0).getDataType() == DataType.MATRIX && //prevent mv
        HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) && bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
            Hop lleft = bop.getInput().get(0).getInput().get(0);
            Hop lright = bop.getInput().get(0).getInput().get(1);
            //a) sum ((X - W * (U %*% t(V))) ^ 2)
            //wuvIndex: input position of the W * (U %*% t(V)) operand within the minus (-1 if none)
            int wuvIndex = -1;
            if (lright instanceof BinaryOp && lright.getInput().get(1) instanceof AggBinaryOp) {
                wuvIndex = 1;
            } else //b) sum ((W * (U %*% t(V)) - X) ^ 2)
            if (lleft instanceof BinaryOp && lleft.getInput().get(1) instanceof AggBinaryOp) {
                wuvIndex = 0;
            }
            if (//rewrite match
            wuvIndex >= 0) {
                Hop X = bop.getInput().get(0).getInput().get((wuvIndex == 0) ? 1 : 0);
                //(W * (U %*% t(V)))
                Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex);
                //the cast below is safe because wuvIndex was only set for a BinaryOp operand;
                //note that only the first candidate side is checked for the MULT operator
                if (((BinaryOp) tmp).getOp() == OpOp2.MULT && tmp.getInput().get(0).getDataType() == DataType.MATRIX && //prevent mv
                HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) && //BLOCKSIZE CONSTRAINT
                HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0), true)) {
                    Hop W = tmp.getInput().get(0);
                    Hop U = tmp.getInput().get(1).getInput().get(0);
                    Hop V = tmp.getInput().get(1).getInput().get(1);
                    //pass V rather than t(V): unwrap an explicit transpose, otherwise create one
                    if (!HopRewriteUtils.isTransposeOperation(V)) {
                        V = HopRewriteUtils.createTranspose(V);
                    } else {
                        V = V.getInput().get(0);
                    }
                    hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
                    HopRewriteUtils.setOutputParametersForScalar(hnew);
                    appliedPattern = true;
                    LOG.debug("Applied simplifyWeightedSquaredLoss2" + wuvIndex + " (line " + hi.getBeginLine() + ")");
                }
            }
        }
        //pattern 3 (no weighting): sum ((X - (U %*% t(V))) ^ 2),
        //incl. alternative sum (((U %*% t(V)) - X) ^ 2)
        if (!appliedPattern && bop.getOp() == OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp) bop.getInput().get(1)) == 2 && HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS) && bop.getInput().get(0).getDataType() == DataType.MATRIX && //prevent mv
        HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) && bop.getInput().get(0).getInput().get(0).getDataType() == DataType.MATRIX) {
            Hop lleft = bop.getInput().get(0).getInput().get(0);
            Hop lright = bop.getInput().get(0).getInput().get(1);
            //a) sum ((X - (U %*% t(V))) ^ 2)
            //uvIndex: input position of the U %*% t(V) matrix multiply within the minus (-1 if none)
            int uvIndex = -1;
            if (//ba guarantees matrices
            lright instanceof AggBinaryOp && //BLOCKSIZE CONSTRAINT
            HopRewriteUtils.isSingleBlock(lright.getInput().get(0), true)) {
                uvIndex = 1;
            } else //b) sum (((U %*% t(V)) - X) ^ 2)
            if (//ba guarantees matrices
            lleft instanceof AggBinaryOp && //BLOCKSIZE CONSTRAINT
            HopRewriteUtils.isSingleBlock(lleft.getInput().get(0), true)) {
                uvIndex = 0;
            }
            if (//rewrite match
            uvIndex >= 0) {
                Hop X = bop.getInput().get(0).getInput().get((uvIndex == 0) ? 1 : 0);
                //(U %*% t(V))
                Hop tmp = bop.getInput().get(0).getInput().get(uvIndex);
                //no weighting: W is the literal 1
                Hop W = new LiteralOp(1);
                Hop U = tmp.getInput().get(0);
                Hop V = tmp.getInput().get(1);
                //pass V rather than t(V): unwrap an explicit transpose, otherwise create one
                if (!HopRewriteUtils.isTransposeOperation(V)) {
                    V = HopRewriteUtils.createTranspose(V);
                } else {
                    V = V.getInput().get(0);
                }
                hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WSLOSS, X, U, V, W, false);
                HopRewriteUtils.setOutputParametersForScalar(hnew);
                appliedPattern = true;
                LOG.debug("Applied simplifyWeightedSquaredLoss3" + uvIndex + " (line " + hi.getBeginLine() + ")");
            }
        }
    }
    //relink new hop into original position
    if (hnew != null) {
        HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
        hi = hnew;
    }
    return hi;
}
Also used: QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Aggregations

LiteralOp (org.apache.sysml.hops.LiteralOp): 81; Hop (org.apache.sysml.hops.Hop): 72; AggUnaryOp (org.apache.sysml.hops.AggUnaryOp): 29; BinaryOp (org.apache.sysml.hops.BinaryOp): 24; AggBinaryOp (org.apache.sysml.hops.AggBinaryOp): 23; UnaryOp (org.apache.sysml.hops.UnaryOp): 23; DataOp (org.apache.sysml.hops.DataOp): 18; IndexingOp (org.apache.sysml.hops.IndexingOp): 15; DataGenOp (org.apache.sysml.hops.DataGenOp): 12; LeftIndexingOp (org.apache.sysml.hops.LeftIndexingOp): 11; HashMap (java.util.HashMap): 9; ArrayList (java.util.ArrayList): 8; ReorgOp (org.apache.sysml.hops.ReorgOp): 8; MatrixObject (org.apache.sysml.runtime.controlprogram.caching.MatrixObject): 8; HopsException (org.apache.sysml.hops.HopsException): 6; ParameterizedBuiltinOp (org.apache.sysml.hops.ParameterizedBuiltinOp): 6; TernaryOp (org.apache.sysml.hops.TernaryOp): 6; DataIdentifier (org.apache.sysml.parser.DataIdentifier): 6; DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 6; QuaternaryOp (org.apache.sysml.hops.QuaternaryOp): 5