
Example 1 with MMulTranspose

Use of org.nd4j.linalg.api.blas.params.MMulTranspose in project nd4j by deeplearning4j.

From class TensorMmul, method initFromOnnx.

@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
    // ONNX Gemm: transA/transB are INT attributes; absent or 0 means "no transpose"
    val isTransposeA = attributesForNode.containsKey("transA") && attributesForNode.get("transA").getI() > 0;
    val isTransposeB = attributesForNode.containsKey("transB") && attributesForNode.get("transB").getI() > 0;
    MMulTranspose mMulTranspose = MMulTranspose.builder().transposeA(isTransposeA).transposeB(isTransposeB).build();
    this.mMulTranspose = mMulTranspose;
}
Also used: lombok.val (lombok.val), MMulTranspose (org.nd4j.linalg.api.blas.params.MMulTranspose)
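The two ternaries above encode one rule: a missing transA/transB attribute means "no transpose" (the ONNX Gemm operator declares both as integer attributes with default 0), and a present one counts as true only when its value is positive. A minimal stdlib sketch of that rule, assuming a simplified Map<String, Long> in place of nd4j's OnnxProto3.AttributeProto map; the class OnnxTransposeFlags and its flag method are hypothetical names, not nd4j API:

```java
import java.util.Map;

// Sketch of the flag-defaulting rule used in initFromOnnx.
// Assumption: attributes are simplified to Map<String, Long>; nd4j itself
// passes OnnxProto3.AttributeProto values and reads them with getI().
final class OnnxTransposeFlags {

    private OnnxTransposeFlags() {}

    // Absent attribute -> false (ONNX Gemm defaults transA/transB to 0).
    // Present attribute -> true only when its integer value is > 0.
    static boolean flag(Map<String, Long> intAttributes, String key) {
        Long value = intAttributes.get(key);
        return value != null && value > 0;
    }

    public static void main(String[] args) {
        Map<String, Long> attrs = Map.of("transA", 1L);
        System.out.println(flag(attrs, "transA") + " " + flag(attrs, "transB")); // true false
    }
}
```

Using `&&` instead of a ternary keeps the "absent means false" default and the "positive means true" test in a single expression.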

Example 2 with MMulTranspose

Use of org.nd4j.linalg.api.blas.params.MMulTranspose in project nd4j by deeplearning4j.

From class Mmul, method initFromTensorFlow.

@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
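    // TensorFlow MatMul: transpose_a/transpose_b are bools that default to false;
    // guard against graphs where default-valued attributes were stripped
    val isTransposeA = attributesForNode.containsKey("transpose_a") && attributesForNode.get("transpose_a").getB();
    val isTransposeB = attributesForNode.containsKey("transpose_b") && attributesForNode.get("transpose_b").getB();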
    MMulTranspose mMulTranspose = MMulTranspose.builder().transposeA(isTransposeA).transposeB(isTransposeB).build();
    this.mMulTranspose = mMulTranspose;
    val args = args();
    for (val arg : args) {
        if (sameDiff.isPlaceHolder(arg.getVarName()) || arg.getShape() == null) {
            sameDiff.addPropertyToResolve(this, arg.getVarName());
        }
    }
}
Also used: lombok.val (lombok.val), MMulTranspose (org.nd4j.linalg.api.blas.params.MMulTranspose)
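The loop at the end of initFromTensorFlow marks every argument that is a placeholder or has no known shape, so its property can be resolved later, once concrete inputs exist. A minimal stdlib sketch of that filtering step; GraphArg and namesToResolveLater are hypothetical stand-ins for SameDiff variables and the addPropertyToResolve call, not nd4j API:

```java
import java.util.ArrayList;
import java.util.List;

// Sketch of the "resolve later" pass in initFromTensorFlow.
// Assumption: GraphArg is a hypothetical stand-in for nd4j's SDVariable.
final class DeferredShapeResolution {

    // Hypothetical argument: a variable name, a placeholder flag,
    // and a shape that stays null until it can be inferred.
    static final class GraphArg {
        final String varName;
        final boolean placeholder;
        final long[] shape;

        GraphArg(String varName, boolean placeholder, long[] shape) {
            this.varName = varName;
            this.placeholder = placeholder;
            this.shape = shape;
        }
    }

    // Returns the names of arguments whose shape must be resolved once
    // real inputs are available: placeholders, or args with unknown shape.
    static List<String> namesToResolveLater(List<GraphArg> args) {
        List<String> pending = new ArrayList<>();
        for (GraphArg arg : args) {
            if (arg.placeholder || arg.shape == null) {
                pending.add(arg.varName);
            }
        }
        return pending;
    }

    public static void main(String[] args) {
        List<GraphArg> graphArgs = List.of(
                new GraphArg("input", true, null),
                new GraphArg("weights", false, new long[] {3, 4}),
                new GraphArg("bias", false, null));
        System.out.println(namesToResolveLater(graphArgs)); // [input, bias]
    }
}
```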

Example 3 with MMulTranspose

Use of org.nd4j.linalg.api.blas.params.MMulTranspose in project nd4j by deeplearning4j.

From class Mmul, method initFromOnnx.

@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
    // ONNX Gemm: transA/transB are INT attributes; absent or 0 means "no transpose"
    val isTransposeA = attributesForNode.containsKey("transA") && attributesForNode.get("transA").getI() > 0;
    val isTransposeB = attributesForNode.containsKey("transB") && attributesForNode.get("transB").getI() > 0;
    MMulTranspose mMulTranspose = MMulTranspose.builder().transposeA(isTransposeA).transposeB(isTransposeB).build();
    this.mMulTranspose = mMulTranspose;
}
Also used: lombok.val (lombok.val), MMulTranspose (org.nd4j.linalg.api.blas.params.MMulTranspose)

Example 4 with MMulTranspose

Use of org.nd4j.linalg.api.blas.params.MMulTranspose in project nd4j by deeplearning4j.

From class SameDiffTests, method testMmulWithTranspose.

@Test
public void testMmulWithTranspose() {
    for (int i : new int[] { 2, 1 }) {
        System.out.println("i = " + i);
        // first: shape [i,3], i.e. [1,3] or [2,3]
        INDArray first = Nd4j.linspace(1, 3 * i, 3 * i).reshape('c', i, 3);
        // second: shape [i,4], i.e. [1,4] or [2,4]
        INDArray second = Nd4j.linspace(4, 4 + 4 * i, 4 * i).reshape('c', i, 4);
        System.out.println("Shapes: " + Arrays.toString(first.shape()) + "\t" + Arrays.toString(second.shape()));
        SameDiff sd = SameDiff.create();
        SDVariable f = sd.var("in1", first);
        SDVariable s = sd.var("in2", second);
        MMulTranspose mt = MMulTranspose.builder().transposeA(true).transposeB(false).transposeResult(false).a(first).b(second).build();
        SDVariable mmul = sd.f().mmul(f, s, mt);
        sd.updateVariableNameAndReference(mmul, "mmul");
        INDArray out = sd.execAndEndResult();
        INDArray exp = first.transpose().mmul(second);
        assertEquals(exp, out);
        System.out.println("----- Finished: i = " + i + " ------");
    }
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), MMulTranspose (org.nd4j.linalg.api.blas.params.MMulTranspose), Test (org.junit.Test)
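The test checks that transposeA(true) produces first^T * second, so a [i,3] and a [i,4] operand give a [3,4] result. The plain-Java reference below is a sketch, assuming the flags mean result = op(A) * op(B) with op(X) = X^T when the matching flag is set, and that transposeResult transposes the product afterwards. TransposedMatMul is a hypothetical class for illustration, not nd4j's code path:

```java
// Plain-Java reference for what the MMulTranspose flags request (assumed semantics):
// result = op(A) * op(B), op(X) = X^T when the flag is set,
// and the product is transposed afterwards when transposeResult is set.
final class TransposedMatMul {

    // Transposes a non-empty rectangular matrix.
    static double[][] transpose(double[][] m) {
        double[][] t = new double[m[0].length][m.length];
        for (int r = 0; r < m.length; r++) {
            for (int c = 0; c < m[0].length; c++) {
                t[c][r] = m[r][c];
            }
        }
        return t;
    }

    static double[][] mmul(double[][] a, double[][] b,
                           boolean transposeA, boolean transposeB, boolean transposeResult) {
        double[][] opA = transposeA ? transpose(a) : a;
        double[][] opB = transposeB ? transpose(b) : b;
        int rows = opA.length, inner = opA[0].length, cols = opB[0].length;
        if (opB.length != inner) {
            throw new IllegalArgumentException("Inner dimensions differ: " + inner + " vs " + opB.length);
        }
        double[][] out = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                double sum = 0;
                for (int p = 0; p < inner; p++) {
                    sum += opA[i][p] * opB[p][j];
                }
                out[i][j] = sum;
            }
        }
        return transposeResult ? transpose(out) : out;
    }

    public static void main(String[] args) {
        // Same shapes as the test with i = 2: first is [2,3], second is [2,4].
        double[][] first = {{1, 2, 3}, {4, 5, 6}};
        double[][] second = {{1, 0, 0, 1}, {0, 1, 1, 0}};
        double[][] out = mmul(first, second, true, false, false);
        System.out.println(out.length + "x" + out[0].length); // 3x4
    }
}
```

Transposing an operand before multiplying is what makes the [i,3] x [i,4] pair compatible; without transposeA the inner dimensions (3 and i) would not match.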

Example 5 with MMulTranspose

Use of org.nd4j.linalg.api.blas.params.MMulTranspose in project nd4j by deeplearning4j.

From class Nd4jTestsC, method testMmulWithTranspose.

@Test
public void testMmulWithTranspose() {
    INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    INDArray arr2 = Nd4j.linspace(1, 4, 4).reshape(2, 2).transpose();
    INDArray arrTransposeAssertion = arr.transpose().mmul(arr2);
    MMulTranspose mMulTranspose = MMulTranspose.builder().transposeA(true).a(arr).b(arr2).build();
    INDArray testResult = arr.mmul(arr2, mMulTranspose);
    assertEquals(arrTransposeAssertion, testResult);
    INDArray bTransposeAssertion = arr.mmul(arr2.transpose());
    mMulTranspose = MMulTranspose.builder().transposeB(true).a(arr).b(arr2).build();
    INDArray bTest = arr.mmul(arr2, mMulTranspose);
    assertEquals(bTransposeAssertion, bTest);
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), MMulTranspose (org.nd4j.linalg.api.blas.params.MMulTranspose), Test (org.junit.Test)
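The two assertions in this test can be checked by hand. Assuming reshape(2, 2) fills row-major ('c' order, the nd4j default), write A for arr; then arr2 is A^T, and the expected results work out as follows. The two results are transposes of each other because arr2 happens to equal arr^T.

```latex
A = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}, \qquad
B = A^{\mathsf T} = \begin{pmatrix} 1 & 3 \\ 2 & 4 \end{pmatrix}

\text{transposeA: } A^{\mathsf T} B = A^{\mathsf T} A^{\mathsf T} = (A A)^{\mathsf T}
  = \begin{pmatrix} 7 & 15 \\ 10 & 22 \end{pmatrix}

\text{transposeB: } A B^{\mathsf T} = A A
  = \begin{pmatrix} 7 & 10 \\ 15 & 22 \end{pmatrix}
```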

Aggregations

MMulTranspose (org.nd4j.linalg.api.blas.params.MMulTranspose): 7
lombok.val (lombok.val): 4
Test (org.junit.Test): 3
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 3
DynamicCustomOp (org.nd4j.linalg.api.ops.DynamicCustomOp): 1
Mmul (org.nd4j.linalg.api.ops.impl.accum.Mmul): 1