Use of org.nd4j.linalg.api.blas.params.MMulTranspose in project nd4j by deeplearning4j.
Class TensorMmul, method initFromOnnx.
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
    // ONNX encodes the transpose flags as integer attributes: absent or 0 means no transpose, > 0 means transpose.
    val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0;
    val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0;
    MMulTranspose mMulTranspose = MMulTranspose.builder().transposeA(isTransposeA).transposeB(isTransposeB).build();
    this.mMulTranspose = mMulTranspose;
}
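The ONNX handling above boils down to an integer-to-boolean conversion: a missing attribute means no transpose, a value greater than zero means transpose. A minimal standalone sketch of that rule, using a hypothetical helper and a plain Map<String, Long> in place of the OnnxProto3 attribute map:

import java.util.Map;

// Hypothetical helper, not part of nd4j: mirrors the transA/transB logic in initFromOnnx.
final class OnnxTransposeFlags {
    static boolean flag(Map<String, Long> attrs, String key) {
        Long value = attrs.get(key);
        return value != null && value > 0;   // absent or <= 0 means "do not transpose"
    }
}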
Use of org.nd4j.linalg.api.blas.params.MMulTranspose in project nd4j by deeplearning4j.
Class Mmul, method initFromTensorFlow.
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
    // TensorFlow MatMul exposes the transpose flags as boolean attributes.
    val isTransposeA = attributesForNode.get("transpose_a").getB();
    val isTransposeB = attributesForNode.get("transpose_b").getB();
    MMulTranspose mMulTranspose = MMulTranspose.builder().transposeA(isTransposeA).transposeB(isTransposeB).build();
    this.mMulTranspose = mMulTranspose;
    // Arguments that are placeholders or have no shape yet are registered for later property resolution.
    val args = args();
    for (val arg : args) {
        if (sameDiff.isPlaceHolder(arg.getVarName()) || arg.getShape() == null) {
            sameDiff.addPropertyToResolve(this, arg.getVarName());
        }
    }
}
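Once built from the transpose_a/transpose_b attributes, the MMulTranspose can also be passed to a plain INDArray multiply, as the Nd4jTestsC example at the bottom of this page does. A minimal sketch with made-up 2x2 inputs:

// Sketch only: with transposeA = true the multiply computes a^T * b.
INDArray a = Nd4j.linspace(1, 4, 4).reshape(2, 2);
INDArray b = Nd4j.linspace(5, 8, 4).reshape(2, 2);
MMulTranspose tA = MMulTranspose.builder().transposeA(true).a(a).b(b).build();
INDArray out = a.mmul(b, tA);   // same result as a.transpose().mmul(b)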
Use of org.nd4j.linalg.api.blas.params.MMulTranspose in project nd4j by deeplearning4j.
Class Mmul, method initFromOnnx.
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
    val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0;
    val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0;
    MMulTranspose mMulTranspose = MMulTranspose.builder().transposeA(isTransposeA).transposeB(isTransposeB).build();
    this.mMulTranspose = mMulTranspose;
}
Use of org.nd4j.linalg.api.blas.params.MMulTranspose in project nd4j by deeplearning4j.
Class SameDiffTests, method testMmulWithTranspose.
@Test
public void testMmulWithTranspose() {
    for (int i : new int[] { 2, 1 }) {
        System.out.println("i = " + i);
        INDArray first = Nd4j.linspace(1, 3 * i, 3 * i).reshape('c', i, 3);      // To [1,3] or [2,3]
        INDArray second = Nd4j.linspace(4, 4 + 4 * i, 4 * i).reshape('c', i, 4); // To [1,4] or [2,4]
        System.out.println("Shapes: " + Arrays.toString(first.shape()) + "\t" + Arrays.toString(second.shape()));
        SameDiff sd = SameDiff.create();
        SDVariable f = sd.var("in1", first);
        SDVariable s = sd.var("in2", second);
        MMulTranspose mt = MMulTranspose.builder().transposeA(true).transposeB(false).transposeResult(false).a(first).b(second).build();
        SDVariable mmul = sd.f().mmul(f, s, mt);
        sd.updateVariableNameAndReference(mmul, "mmul");
        INDArray out = sd.execAndEndResult();
        INDArray exp = first.transpose().mmul(second);
        assertEquals(exp, out);
        System.out.println("----- Finished: i = " + i + " ------");
    }
}
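For the i = 2 iteration, first is [2,3] and second is [2,4], so transposing A gives a [3,2] x [2,4] product with a [3,4] result. A standalone shape check under those values:

// Sketch of the expected shape for i = 2: [2,3]^T x [2,4] -> [3,4].
INDArray first = Nd4j.linspace(1, 6, 6).reshape('c', 2, 3);
INDArray second = Nd4j.linspace(4, 12, 8).reshape('c', 2, 4);
INDArray exp = first.transpose().mmul(second);
System.out.println(Arrays.toString(exp.shape()));   // prints [3, 4]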
Use of org.nd4j.linalg.api.blas.params.MMulTranspose in project nd4j by deeplearning4j.
Class Nd4jTestsC, method testMmulWithTranspose.
@Test
public void testMmulWithTranspose() {
    INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2);
    INDArray arr2 = Nd4j.linspace(1, 4, 4).reshape(2, 2).transpose();
    // Transposing the first operand via MMulTranspose must match an explicit transpose.
    INDArray arrTransposeAssertion = arr.transpose().mmul(arr2);
    MMulTranspose mMulTranspose = MMulTranspose.builder().transposeA(true).a(arr).b(arr2).build();
    INDArray testResult = arr.mmul(arr2, mMulTranspose);
    assertEquals(arrTransposeAssertion, testResult);
    // Same check with the second operand transposed.
    INDArray bTransposeAssertion = arr.mmul(arr2.transpose());
    mMulTranspose = MMulTranspose.builder().transposeB(true).a(arr).b(arr2).build();
    INDArray bTest = arr.mmul(arr2, mMulTranspose);
    assertEquals(bTransposeAssertion, bTest);
}
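Both flags can be set on a single builder as well; the result should then match transposing both operands explicitly. A hedged sketch in the same style as the test above, not taken from the nd4j test suite:

// Sketch only: transpose both operands via one MMulTranspose.
INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2);
INDArray arr2 = Nd4j.linspace(5, 8, 4).reshape(2, 2);
MMulTranspose both = MMulTranspose.builder().transposeA(true).transposeB(true).a(arr).b(arr2).build();
INDArray expected = arr.transpose().mmul(arr2.transpose());
INDArray actual = arr.mmul(arr2, both);
System.out.println(expected.equals(actual));   // expected: true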