Search in sources:

Example 81 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

The class RewriteAlgebraicSimplificationDynamic, method simplifyWeightedSigmoidMMChains.

/**
 * Fuses a weighted sigmoid matrix-multiplication chain rooted at {@code hi} into a single
 * WSIGMOID quaternary operator. Recognized patterns (W is the left input of the cell-wise
 * multiply; the right operand of the matrix multiply is used as-is when it is already a
 * transpose t(X), otherwise a transpose is created for it):
 * <ol>
 *   <li>W * sigmoid(Y%*%t(X))                 (basic)</li>
 *   <li>W * sigmoid(-(Y%*%t(X)))              (minus, written as 0 - (...))</li>
 *   <li>W * log(sigmoid(Y%*%t(X)))            (log)</li>
 *   <li>W * log(sigmoid(-(Y%*%t(X))))         (log_minus)</li>
 * </ol>
 * A rewrite is only attempted when {@code hi} is a cell-wise MULT with more than one column
 * (not vector-vector), both inputs have equal size (no matrix-vector), W is a matrix and the
 * right input is a UnaryOp. Every pattern additionally requires
 * {@code HopRewriteUtils.isSingleBlock(Y, true)}.
 *
 * @param parent parent hop whose child reference is replaced when a pattern matches
 * @param hi     candidate hop
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the new WSIGMOID hop if a pattern matched, otherwise {@code hi}
 */
private static Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos) {
    boolean candidate = HopRewriteUtils.isBinary(hi, OpOp2.MULT)
        && hi.getDim2() > 1
        && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1))
        && hi.getInput().get(0).getDataType() == DataType.MATRIX
        && hi.getInput().get(1) instanceof UnaryOp;
    if (!candidate) {
        return hi;
    }

    Hop weights = hi.getInput().get(0);
    UnaryOp outer = (UnaryOp) hi.getInput().get(1);
    Hop inner = outer.getInput().get(0);
    Hop fused = null;

    if (outer.getOp() == OpOp1.SIGMOID
        && inner instanceof AggBinaryOp
        && HopRewriteUtils.isSingleBlock(inner.getInput().get(0), true)) {
        // Pattern 1) W * sigmoid(Y%*%t(X))
        fused = buildWSigmoidFromMatMult(hi, weights, inner, false, false);
        LOG.debug("Applied simplifyWeightedSigmoid1 (line " + hi.getBeginLine() + ")");
    } else if (outer.getOp() == OpOp1.SIGMOID && isNegatedSingleBlockMatMult(inner)) {
        // Pattern 2) W * sigmoid(-(Y%*%t(X)))
        fused = buildWSigmoidFromMatMult(hi, weights, inner.getInput().get(1), false, true);
        LOG.debug("Applied simplifyWeightedSigmoid2 (line " + hi.getBeginLine() + ")");
    } else if (outer.getOp() == OpOp1.LOG
        && HopRewriteUtils.isUnary(inner, OpOp1.SIGMOID)
        && inner.getInput().get(0) instanceof AggBinaryOp
        && HopRewriteUtils.isSingleBlock(inner.getInput().get(0).getInput().get(0), true)) {
        // Pattern 3) W * log(sigmoid(Y%*%t(X)))
        fused = buildWSigmoidFromMatMult(hi, weights, inner.getInput().get(0), true, false);
        LOG.debug("Applied simplifyWeightedSigmoid3 (line " + hi.getBeginLine() + ")");
    } else if (outer.getOp() == OpOp1.LOG
        && HopRewriteUtils.isUnary(inner, OpOp1.SIGMOID)
        && isNegatedSingleBlockMatMult(inner.getInput().get(0))) {
        // Pattern 4) W * log(sigmoid(-(Y%*%t(X))))
        Hop negation = inner.getInput().get(0);
        fused = buildWSigmoidFromMatMult(hi, weights, negation.getInput().get(1), true, true);
        LOG.debug("Applied simplifyWeightedSigmoid4 (line " + hi.getBeginLine() + ")");
    }

    if (fused == null) {
        return hi;
    }
    // relink the fused operator into the original position
    HopRewriteUtils.replaceChildReference(parent, hi, fused, pos);
    return fused;
}

/**
 * Tests for {@code 0 - (Y %*% Z)} where the left operand is a literal equal to 0, the right
 * operand is an AggBinaryOp, and Y satisfies {@code HopRewriteUtils.isSingleBlock(Y, true)}.
 */
private static boolean isNegatedSingleBlockMatMult(Hop h) {
    return HopRewriteUtils.isBinary(h, OpOp2.MINUS)
        && h.getInput().get(0) instanceof LiteralOp
        && HopRewriteUtils.getDoubleValueSafe((LiteralOp) h.getInput().get(0)) == 0
        && h.getInput().get(1) instanceof AggBinaryOp
        && HopRewriteUtils.isSingleBlock(h.getInput().get(1).getInput().get(0), true);
}

/**
 * Builds a WSIGMOID QuaternaryOp named after {@code hi} from the weights and the matrix
 * multiply {@code Y %*% tX}. The X operand is the input of tX when tX is already a transpose,
 * otherwise a new transpose of tX. Output block sizes are taken from the weights, and size
 * information is refreshed.
 *
 * @param logOut  first QuaternaryOp flag; true for the log(...) patterns
 * @param minusIn second QuaternaryOp flag; true for the negated-input patterns
 */
private static Hop buildWSigmoidFromMatMult(Hop hi, Hop weights, Hop matMult, boolean logOut, boolean minusIn) {
    Hop left = matMult.getInput().get(0);
    Hop right = matMult.getInput().get(1);
    Hop xOperand;
    if (HopRewriteUtils.isTransposeOperation(right)) {
        xOperand = right.getInput().get(0);
    } else {
        xOperand = HopRewriteUtils.createTranspose(right);
    }
    Hop op = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, OpOp4.WSIGMOID,
        weights, left, xOperand, logOut, minusIn);
    op.setOutputBlocksizes(weights.getRowsInBlock(), weights.getColsInBlock());
    op.refreshSizeInformation();
    return op;
}
Also used: QuaternaryOp(org.apache.sysml.hops.QuaternaryOp) AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) UnaryOp(org.apache.sysml.hops.UnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp) BinaryOp(org.apache.sysml.hops.BinaryOp)

Example 82 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

The class RewriteAlgebraicSimplificationDynamic, method simplifyIdentityRepMatrixMult.

/**
 * Removes a matrix multiplication by a constant 1x1 all-ones matrix:
 * {@code X %*% y -> X}, where y is a RAND data generator with known dimensions 1x1 and
 * constant value 1.0 (i.e. {@code matrix(1, ...)}).
 *
 * @param parent parent hop whose child reference is replaced when the rewrite applies
 * @param hi     candidate hop
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the left input X if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyIdentityRepMatrixMult(Hop parent, Hop hi, int pos) {
    if (HopRewriteUtils.isMatrixMultiply(hi)) {
        Hop left = hi.getInput().get(0);
        Hop right = hi.getInput().get(1);
        // X %*% y -> X, if y is a 1x1 matrix(1, ...)
        if (HopRewriteUtils.isDimsKnown(right) && right.getDim1() == 1 && right.getDim2() == 1 // 1x1 right input
            && right instanceof DataGenOp && ((DataGenOp) right).getOp() == DataGenMethod.RAND
            && ((DataGenOp) right).hasConstantValue(1.0)) { // constant 1 generator, i.e. matrix(1, ...)
            HopRewriteUtils.replaceChildReference(parent, hi, left, pos);
            hi = left;
            // log under the method's actual name (previously misspelled "simplifyIdentiyMatrixMult")
            LOG.debug("Applied simplifyIdentityRepMatrixMult");
        }
    }
    return hi;
}
Also used : DataGenOp(org.apache.sysml.hops.DataGenOp) Hop(org.apache.sysml.hops.Hop)

Example 83 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

The class RewriteAlgebraicSimplificationDynamic, method simplifyMatrixMultDiag.

/**
 * Replaces a matrix multiply whose left input is a vector-to-matrix diag by a cell-wise
 * multiply:
 * <ul>
 *   <li>{@code diag(v) %*% y -> v * y} when the right input has exactly one column;</li>
 *   <li>{@code diag(v) %*% Y -> Y * v} when the right input has more than one column. The
 *       operand order is switched because matrix-vector binary cell operations require the
 *       vector on the right. Earlier versions replicated the left side instead; that is no
 *       longer required.</li>
 * </ul>
 * The left diag must have known dimensions and more than one column. If the right input's
 * column count is neither 1 nor greater than 1 (e.g. unknown), nothing is rewritten.
 * Analogous rewrites for a diag on the right side would be possible but are not implemented.
 *
 * @param parent parent hop whose child reference is replaced when the rewrite applies
 * @param hi     candidate hop
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the new cell-wise multiply if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyMatrixMultDiag(Hop parent, Hop hi, int pos) {
    if (!HopRewriteUtils.isMatrixMultiply(hi)) {
        return hi;
    }
    Hop left = hi.getInput().get(0);
    Hop right = hi.getInput().get(1);

    boolean leftIsDiagV2M = left instanceof ReorgOp
        && ((ReorgOp) left).getOp() == ReOrgOp.DIAG
        && HopRewriteUtils.isDimsKnown(left)
        && left.getDim2() > 1;
    if (!leftIsDiagV2M) {
        return hi;
    }

    Hop diagSource = left.getInput().get(0);
    long rightCols = right.getDim2();
    Hop product;
    if (rightCols == 1) {
        // column vector on the right: v * y
        product = HopRewriteUtils.createBinary(diagSource, right, OpOp2.MULT);
        LOG.debug("Applied simplifyMatrixMultDiag1");
    } else if (rightCols > 1) {
        // multi-column right input: Y * v (vector must be the right operand)
        product = HopRewriteUtils.createBinary(right, diagSource, OpOp2.MULT);
        LOG.debug("Applied simplifyMatrixMultDiag2");
    } else {
        return hi;
    }

    HopRewriteUtils.replaceChildReference(parent, hi, product, pos);
    HopRewriteUtils.cleanupUnreferenced(hi);
    return product;
}
Also used : Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp)

Example 84 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

The class RewriteAlgebraicSimplificationDynamic, method simplifyEmptySortOperation.

/**
 * Simplifies a sort (order) over an input for which {@code HopRewriteUtils.isEmpty} holds,
 * provided the index-return flag (input 3) is a literal:
 * <ul>
 *   <li>{@code order(X, indexreturn=TRUE) -> seq(1, nrow(X), 1)} via
 *       {@code createSeqDataGenOp};</li>
 *   <li>{@code order(X, indexreturn=FALSE)} -> a data generator with value 0 shaped after
 *       the input, via {@code createDataGenOp(input, 0)}.</li>
 * </ul>
 *
 * @param parent parent hop whose child reference is replaced when the rewrite applies
 * @param hi     candidate hop
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the replacement hop if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyEmptySortOperation(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof ReorgOp) || ((ReorgOp) hi).getOp() != ReOrgOp.SORT) {
        return hi;
    }
    ReorgOp sortOp = (ReorgOp) hi;
    Hop data = sortOp.getInput().get(0);

    // only applies to empty inputs with a compile-time-known index-return flag
    if (!HopRewriteUtils.isEmpty(data) || !(sortOp.getInput().get(3) instanceof LiteralOp)) {
        return hi;
    }

    boolean indexReturn = HopRewriteUtils.getBooleanValue((LiteralOp) sortOp.getInput().get(3));
    Hop replacement;
    if (indexReturn) {
        replacement = HopRewriteUtils.createSeqDataGenOp(data);
    } else {
        replacement = HopRewriteUtils.createDataGenOp(data, 0);
    }

    if (replacement == null) {
        return hi;
    }
    HopRewriteUtils.replaceChildReference(parent, hi, replacement, pos);
    LOG.debug("Applied simplifyEmptySortOperation (indexreturn=" + indexReturn + ").");
    return replacement;
}
Also used : ReorgOp(org.apache.sysml.hops.ReorgOp) Hop(org.apache.sysml.hops.Hop) LiteralOp(org.apache.sysml.hops.LiteralOp)

Example 85 with Hop

Use of org.apache.sysml.hops.Hop in project incubator-systemml by Apache.

The class RewriteAlgebraicSimplificationDynamic, method simplifyColSumsMVMult.

/**
 * Rewrites the column sums of a matrix-times-column-vector cell-wise product into a matrix
 * multiply: {@code colSums(X * y) -> t(y) %*% X}. It applies when X has more than one row and
 * column, and y has more than one row and exactly one column. The introduced transpose can be
 * removed by another rewrite if unnecessary, i.e. if y == t(Z).
 *
 * @param parent parent hop whose child reference is replaced when the rewrite applies
 * @param hi     candidate hop
 * @param pos    position of {@code hi} among the inputs of {@code parent}
 * @return the new matrix multiply if the rewrite applied, otherwise {@code hi}
 */
private static Hop simplifyColSumsMVMult(Hop parent, Hop hi, int pos) {
    if (!(hi instanceof AggUnaryOp)) {
        return hi;
    }
    AggUnaryOp colAgg = (AggUnaryOp) hi;
    Hop product = colAgg.getInput().get(0);

    boolean isColSumOfMult = colAgg.getOp() == AggOp.SUM
        && colAgg.getDirection() == Direction.Col
        && HopRewriteUtils.isBinary(product, OpOp2.MULT);
    if (!isColSumOfMult) {
        return hi;
    }

    Hop matrix = product.getInput().get(0);
    Hop vector = product.getInput().get(1);
    boolean isMatrixTimesColVector = matrix.getDim1() > 1 && matrix.getDim2() > 1
        && vector.getDim1() > 1 && vector.getDim2() == 1;
    if (!isMatrixTimesColVector) {
        return hi;
    }

    ReorgOp vectorT = HopRewriteUtils.createTranspose(vector);
    AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(vectorT, matrix);
    HopRewriteUtils.replaceChildReference(parent, hi, mmult, pos);
    HopRewriteUtils.cleanupUnreferenced(colAgg, product);
    LOG.debug("Applied simplifyColSumsMVMult");
    return mmult;
}
Also used : AggUnaryOp(org.apache.sysml.hops.AggUnaryOp) AggBinaryOp(org.apache.sysml.hops.AggBinaryOp) Hop(org.apache.sysml.hops.Hop) ReorgOp(org.apache.sysml.hops.ReorgOp)

Aggregations

Hop (org.apache.sysml.hops.Hop)307 LiteralOp (org.apache.sysml.hops.LiteralOp)94 AggBinaryOp (org.apache.sysml.hops.AggBinaryOp)65 BinaryOp (org.apache.sysml.hops.BinaryOp)63 ArrayList (java.util.ArrayList)61 AggUnaryOp (org.apache.sysml.hops.AggUnaryOp)61 HashMap (java.util.HashMap)44 DataOp (org.apache.sysml.hops.DataOp)41 UnaryOp (org.apache.sysml.hops.UnaryOp)41 HashSet (java.util.HashSet)39 ReorgOp (org.apache.sysml.hops.ReorgOp)32 MemoTableEntry (org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry)28 StatementBlock (org.apache.sysml.parser.StatementBlock)28 IndexingOp (org.apache.sysml.hops.IndexingOp)24 ForStatementBlock (org.apache.sysml.parser.ForStatementBlock)23 WhileStatementBlock (org.apache.sysml.parser.WhileStatementBlock)23 IfStatementBlock (org.apache.sysml.parser.IfStatementBlock)22 DataGenOp (org.apache.sysml.hops.DataGenOp)21 DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException)21 HopsException (org.apache.sysml.hops.HopsException)18