Usage of org.nd4j.autodiff.samediff.SDVariable in the nd4j project (deeplearning4j).
Class GradCheckTransforms, method invertPermutation.
/**
 * Checks that {@code SameDiff.invertPermutation} maps the permutation [3, 4, 0, 2, 1] to its
 * inverse [2, 4, 3, 0, 1].
 *
 * <p>The result is verified with an explicit {@link AssertionError} instead of the {@code assert}
 * keyword: Java assertions are disabled unless the JVM is started with {@code -ea}, so a bare
 * {@code assert} would let this test pass without checking anything.
 */
@Test
public void invertPermutation() {
    SameDiff sd = SameDiff.create();
    // Permutation p = [3, 4, 0, 2, 1]; its inverse q satisfies q[p[i]] = i, giving q = [2, 4, 3, 0, 1].
    INDArray ia = Nd4j.create(new float[] { 3, 4, 0, 2, 1 });
    INDArray expOut = Nd4j.create(new float[] { 2, 4, 3, 0, 1 });
    SDVariable input = sd.var("in", new int[] { 1, 5 });
    sd.associateArrayWithVariable(ia, input);
    SDVariable out = sd.invertPermutation(input);
    sd.exec();
    INDArray actual = out.getArr();
    if (!expOut.equals(actual)) {
        throw new AssertionError("invertPermutation: expected " + expOut + " but got " + actual);
    }
}
Usage of org.nd4j.autodiff.samediff.SDVariable in the nd4j project (deeplearning4j).
Class TensorFlowImportTest, method importGraph1.
/**
 * Imports the TensorFlow text-format graph {@code tf_graphs/max_add_2.pb.txt} from the test
 * classpath and checks that it exposes exactly two variables, "zeros" and "ones", whose arrays
 * sum to 0.0 and 12.0 respectively.
 *
 * <p>The classpath stream is opened in try-with-resources so its file handle is released once
 * the import returns.
 */
@Test
@Ignore
public void importGraph1() throws Exception {
    SameDiff graph;
    // NOTE(review): assumes importGraph consumes the stream fully before returning,
    // so closing it afterwards is safe — confirm against TFGraphMapper.
    try (java.io.InputStream graphStream = new ClassPathResource("tf_graphs/max_add_2.pb.txt").getInputStream()) {
        graph = TFGraphMapper.getInstance().importGraph(graphStream);
    }
    assertNotNull(graph);
    assertEquals(2, graph.variableMap().size());

    SDVariable var0 = graph.variableMap().get("zeros");
    SDVariable var1 = graph.variableMap().get("ones");
    assertNotNull(var0);
    assertNotNull(var1);
    assertNotNull(var0.getArr());
    assertNotNull(var1.getArr());

    assertEquals(0.0, var0.getArr().sumNumber().doubleValue(), 1e-5);
    assertEquals(12.0, var1.getArr().sumNumber().doubleValue(), 1e-5);
}
Usage of org.nd4j.autodiff.samediff.SDVariable in the nd4j project (deeplearning4j).
Class GradCheckLoss, method testLossWeights2d.
/**
 * Gradient-checks the 2d ([minibatch, nOut]) loss functions for every combination of
 * weight layout, binary vs. real-valued weights, reduction mode and loss function.
 *
 * <p>Weight layouts: "none" (no weight variable), "per-example" [minibatch, 1],
 * "per-output" [1, nOut] and "per-example-output" [minibatch, nOut].
 * Each combination must produce a single-value loss and pass {@code GradCheckUtil.checkGradients}.
 */
@Test
public void testLossWeights2d() {
    String[] weightTypes = new String[] { "none", "per-example", "per-output", "per-example-output" };
    Nd4j.getRandom().setSeed(12345);
    int nOut = 4;
    int minibatch = 10;
    for (String weightType : weightTypes) {
        // Binary mask (like DL4J) or arbitrary weights?
        for (boolean binary : new boolean[] { true, false }) {
            int[] weightShape;
            switch (weightType) {
                case "none":
                    weightShape = null;
                    break;
                case "per-example":
                    weightShape = new int[] { minibatch, 1 };
                    break;
                case "per-output":
                    weightShape = new int[] { 1, nOut };
                    break;
                case "per-example-output":
                    weightShape = new int[] { minibatch, nOut };
                    break;
                default:
                    throw new RuntimeException("Unknown type: " + weightType);
            }

            // The same weight array is reused for every reduction / loss-function pair below.
            INDArray weightArr = null;
            if (!"none".equals(weightType)) {
                if (binary) {
                    // Bernoulli(0.5) mask
                    weightArr = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(weightShape), 0.5));
                } else {
                    weightArr = Nd4j.rand(weightShape).muli(2.0);
                }
            }

            for (LossFunctions.Reduction reduction : new LossFunctions.Reduction[] { LossFunctions.Reduction.MEAN_BY_COUNT, LossFunctions.Reduction.MEAN_BY_WEIGHT, LossFunctions.Reduction.SUM }) {
                for (String fn : new String[] { "mse", "l1", "l2", "mcxent" }) {
                    String msg = "lossFn=" + fn + ", reduction=" + reduction + ", weightType=" + weightType + ", binaryWeight=" + binary;

                    SameDiff sd = SameDiff.create();
                    SDVariable input = sd.var("in", new int[] { -1, nOut });
                    SDVariable labels = sd.var("labels", new int[] { -1, nOut });
                    SDVariable weight = null;
                    if (!"none".equals(weightType)) {
                        weight = sd.var("weights", weightArr);
                    }
                    INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100);
                    INDArray labelsArr = Nd4j.randn(minibatch, nOut).muli(100);

                    // Only the graph-building side effect is needed here; the returned LossInfo is unused.
                    // Dimension 1 is the nOut axis.
                    switch (fn) {
                        case "mse":
                            LossFunctions.mse("out", input, labels, weight, reduction, 1);
                            break;
                        case "l1":
                            // L1 = sum abs error
                            LossFunctions.l1("out", input, labels, weight, reduction, 1);
                            break;
                        case "l2":
                            // L2 = sum squared error
                            LossFunctions.l2("out", input, labels, weight, reduction, 1);
                            break;
                        case "mcxent":
                            // mcxent = sum label * log(prob)
                            LossFunctions.mcxent("out", input, labels, weight, reduction, 1);
                            break;
                        default:
                            throw new RuntimeException("Unknown loss function: " + fn);
                    }

                    log.info("*** Starting test: " + msg);
                    sd.associateArrayWithVariable(inputArr, input);
                    sd.associateArrayWithVariable(labelsArr, labels);
                    if (weight != null) {
                        sd.associateArrayWithVariable(weightArr, weight);
                    }
                    INDArray out = sd.execAndEndResult();
                    assertEquals(msg, 1, out.length());
                    boolean ok = GradCheckUtil.checkGradients(sd);
                    assertTrue(msg, ok);
                }
            }
        }
    }
}
Usage of org.nd4j.autodiff.samediff.SDVariable in the nd4j project (deeplearning4j).
Class GradCheckLoss, method testLossSimple2d.
/**
 * Checks the unweighted 2d loss functions against hand-computed expected values for each
 * reduction mode, then gradient-checks them.
 *
 * <p>The expected value is built in two steps:
 * <ol>
 *   <li>a per-example loss, reducing each row of the [minibatch, nOut] arrays along dimension 1;</li>
 *   <li>a reduction over the examples.</li>
 * </ol>
 * With no weight array, {@code SUM} totals the per-example losses. {@code MEAN_BY_COUNT} and
 * {@code MEAN_BY_WEIGHT} both average them, presumably because every implicit weight is 1.
 */
@Test
public void testLossSimple2d() {
    Nd4j.getRandom().setSeed(12345);
    int nOut = 4;
    int minibatch = 10;
    for (String fn : new String[] { "mse", "l1", "l2", "mcxent" }) {
        for (LossFunctions.Reduction reduction : new LossFunctions.Reduction[] { LossFunctions.Reduction.MEAN_BY_COUNT, LossFunctions.Reduction.MEAN_BY_WEIGHT, LossFunctions.Reduction.SUM }) {
            SameDiff sd = SameDiff.create();
            SDVariable input = sd.var("in", new int[] { -1, nOut });
            SDVariable labels = sd.var("labels", new int[] { -1, nOut });
            INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100);
            INDArray labelsArr = Nd4j.randn(minibatch, nOut).muli(100);

            LossInfo lossInfo;
            // One loss value per example (reduced along dimension 1, the nOut axis).
            INDArray perExample;
            switch (fn) {
                case "mse":
                    lossInfo = LossFunctions.mse("out", input, labels, null, reduction, 1);
                    // MSE = mean of the squared error over the nOut outputs of each example
                    perExample = Transforms.pow(inputArr.sub(labelsArr), 2.0).mean(1);
                    break;
                case "l1":
                    lossInfo = LossFunctions.l1("out", input, labels, null, reduction, 1);
                    // L1 = sum abs error
                    perExample = Transforms.abs(inputArr.sub(labelsArr)).sum(1);
                    break;
                case "l2":
                    lossInfo = LossFunctions.l2("out", input, labels, null, reduction, 1);
                    // L2 = sum squared error
                    perExample = Transforms.pow(inputArr.sub(labelsArr), 2.0).sum(1);
                    break;
                case "mcxent":
                    lossInfo = LossFunctions.mcxent("out", input, labels, null, reduction, 1);
                    // mcxent = sum label * log(prob)
                    // NOTE(review): inputArr is unbounded randn * 100, so log() of its negative entries
                    // yields NaN, and standard cross-entropy is usually negated. Confirm this
                    // expectation against LossFunctions.mcxent.
                    perExample = labelsArr.mul(Transforms.log(inputArr)).sum(1);
                    break;
                default:
                    throw new RuntimeException("Unknown loss function: " + fn);
            }

            // SUM totals the per-example losses; the mean-based reductions average them.
            INDArray expOut = reduction == LossFunctions.Reduction.SUM
                    ? perExample.sum(Integer.MAX_VALUE)
                    : perExample.mean(Integer.MAX_VALUE);

            String msg = "test: " + lossInfo.getLossName() + ", reduction=" + reduction;
            log.info("*** Starting test: " + msg);
            sd.associateArrayWithVariable(inputArr, input);
            sd.associateArrayWithVariable(labelsArr, labels);
            INDArray out = sd.execAndEndResult();
            assertEquals(msg, expOut, out);

            log.info("STARTING GRADIENT CHECK");
            boolean ok = GradCheckUtil.checkGradients(sd);
            assertTrue(msg, ok);
        }
    }
}
Usage of org.nd4j.autodiff.samediff.SDVariable in the nd4j project (deeplearning4j).
Class GradCheckLoss, method testLossWeights4d.
/**
 * Gradient-checks the 4d loss functions on [minibatch, depth, h, w] arrays (NCHW layout) for
 * every combination of weight layout, binary vs. real-valued weights, reduction mode and loss
 * function.
 *
 * <p>Each weight layout is a shape that broadcasts against [minibatch, depth, h, w]: size-1 axes
 * are shared, full-size axes vary ("per-all" is the full shape, "none" uses no weight variable).
 * Each combination must produce a single-value loss and pass {@code GradCheckUtil.checkGradients}.
 */
@Test
public void testLossWeights4d() {
    String[] weightTypes = new String[] { "none", "per-example", "per-depth", "per-height", "per-width", "per-height-width", "per-depth-height", "per-depth-width", "per-example-depth", "per-example-height", "per-example-height-width", "per-all" };
    Nd4j.getRandom().setSeed(12345);
    // Assume NCHW format here
    int minibatch = 10;
    int depth = 4;
    int h = 5;
    int w = 6;
    for (String weightType : weightTypes) {
        // Binary mask (like DL4J) or arbitrary weights?
        for (boolean binary : new boolean[] { true, false }) {
            int[] weightShape;
            switch (weightType) {
                case "none":
                    weightShape = null;
                    break;
                case "per-example":
                    weightShape = new int[] { minibatch, 1, 1, 1 };
                    break;
                case "per-depth":
                    weightShape = new int[] { 1, depth, 1, 1 };
                    break;
                case "per-height":
                    weightShape = new int[] { 1, 1, h, 1 };
                    break;
                case "per-width":
                    weightShape = new int[] { 1, 1, 1, w };
                    break;
                case "per-height-width":
                    weightShape = new int[] { 1, 1, h, w };
                    break;
                case "per-depth-height":
                    weightShape = new int[] { 1, depth, h, 1 };
                    break;
                case "per-depth-width":
                    weightShape = new int[] { 1, depth, 1, w };
                    break;
                case "per-example-depth":
                    weightShape = new int[] { minibatch, depth, 1, 1 };
                    break;
                case "per-example-height":
                    weightShape = new int[] { minibatch, 1, h, 1 };
                    break;
                case "per-example-height-width":
                    weightShape = new int[] { minibatch, 1, h, w };
                    break;
                case "per-all":
                    weightShape = new int[] { minibatch, depth, h, w };
                    break;
                default:
                    throw new RuntimeException("Unknown type: " + weightType);
            }

            // The same weight array is reused for every reduction / loss-function pair below.
            INDArray weightArr = null;
            if (!"none".equals(weightType)) {
                if (binary) {
                    // Bernoulli(0.5) mask
                    weightArr = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(weightShape), 0.5));
                } else {
                    weightArr = Nd4j.rand(weightShape).muli(2.0);
                }
            }

            for (LossFunctions.Reduction reduction : new LossFunctions.Reduction[] { LossFunctions.Reduction.MEAN_BY_COUNT, LossFunctions.Reduction.MEAN_BY_WEIGHT, LossFunctions.Reduction.SUM }) {
                for (String fn : new String[] { "mse", "l1", "l2", "mcxent" }) {
                    String msg = "lossFn=" + fn + ", reduction=" + reduction + ", weightType=" + weightType + ", binaryWeight=" + binary;

                    SameDiff sd = SameDiff.create();
                    SDVariable input = sd.var("in", new int[] { -1, depth, -1, -1 });
                    SDVariable labels = sd.var("labels", new int[] { -1, depth, -1, -1 });
                    SDVariable weight = null;
                    if (!"none".equals(weightType)) {
                        weight = sd.var("weights", weightArr);
                    }
                    INDArray inputArr = Nd4j.randn(new int[] { minibatch, depth, h, w }).muli(10);
                    INDArray labelsArr = Nd4j.randn(new int[] { minibatch, depth, h, w }).muli(10);

                    // Only the graph-building side effect is needed here; the returned LossInfo is unused.
                    // Dimensions 1, 2, 3 are depth, height and width under the NCHW layout above.
                    switch (fn) {
                        case "mse":
                            LossFunctions.mse("out", input, labels, weight, reduction, 1, 2, 3);
                            break;
                        case "l1":
                            // L1 = sum abs error
                            LossFunctions.l1("out", input, labels, weight, reduction, 1, 2, 3);
                            break;
                        case "l2":
                            // L2 = sum squared error
                            LossFunctions.l2("out", input, labels, weight, reduction, 1, 2, 3);
                            break;
                        case "mcxent":
                            // mcxent = sum label * log(prob)
                            LossFunctions.mcxent("out", input, labels, weight, reduction, 1, 2, 3);
                            break;
                        default:
                            throw new RuntimeException("Unknown loss function: " + fn);
                    }

                    log.info("*** Starting test: " + msg);
                    sd.associateArrayWithVariable(inputArr, input);
                    sd.associateArrayWithVariable(labelsArr, labels);
                    if (weight != null) {
                        sd.associateArrayWithVariable(weightArr, weight);
                    }
                    INDArray out = sd.execAndEndResult();
                    assertEquals(msg, 1, out.length());
                    boolean ok = GradCheckUtil.checkGradients(sd);
                    assertTrue(msg, ok);
                }
            }
        }
    }
}
Aggregations