Use of org.nd4j.autodiff.samediff.SDVariable in the project nd4j by deeplearning4j.
In the class LossFunctions, the method nonZeroCount:
/**
 * Counts the number of non-zero weight entries, broadcast to the shape of the labels.
 *
 * @param weights the weights array (broadcastable to the labels' shape)
 * @param labels  the labels array that defines the broadcast target shape
 * @return a scalar variable holding the count of non-zero weight entries after broadcasting
 */
private static SDVariable nonZeroCount(SDVariable weights, SDVariable labels) {
    SameDiff sameDiff = weights.getSameDiff();
    // Mask: 1.0 where a weight entry is non-zero, 0.0 elsewhere
    SDVariable nonZeroMask = sameDiff.neq(weights, 0.0);
    // Adding the mask to a zero array shaped like the labels broadcasts it to that shape
    SDVariable broadcastMask = sameDiff.zerosLike(labels).add(nonZeroMask);
    return sameDiff.sum(broadcastMask);
}
Use of org.nd4j.autodiff.samediff.SDVariable in the project nd4j by deeplearning4j.
In the class LossFunctions, the method doReduce:
/**
 * Perform the final reduction on the loss function.
 *
 * @param sd            the SameDiff instance the variables belong to
 * @param outputName    name to assign to the final loss variable
 * @param isMean        whether the per-example loss is a mean (vs. a sum) over the loss dimensions
 * @param b             builder accumulating the loss info; this method sets its loss variable
 * @param reduction     how the per-example losses are reduced to the final loss
 * @param preReduceLoss the elementwise (pre-reduction) loss values
 * @param label         the labels (used to determine the broadcast shape for weight counting)
 * @param weights       the weights array (used for weighted/count-based reductions)
 * @param dimensions    the dimensions along which the per-example reduction is performed
 * @return the built LossInfo containing the final loss variable
 */
private static LossInfo doReduce(SameDiff sd, String outputName, boolean isMean, LossInfo.Builder b, Reduction reduction, SDVariable preReduceLoss, SDVariable label, SDVariable weights, int[] dimensions) {
    switch(reduction) {
        case NONE:
            // Return same shape as predictions/labels
            b.loss(preReduceLoss);
            break;
        case SPECIFIED_DIMS:
            // Reduce along specified dimensions
            if (isMean) {
                // Example: MSE + mean along examples
                b.loss(sd.mean(outputName, preReduceLoss, dimensions));
            } else {
                // Example: L1 loss (sum) + mean along examples
                b.loss(sd.sum(outputName, preReduceLoss, dimensions));
            }
            // FIX: missing break caused fall-through into SUM, which overwrote the loss
            // (and attempted to create a second variable with the same output name)
            break;
        case SUM:
            if (isMean) {
                // Example: MSE (mean) + sum along examples
                SDVariable m = sd.mean(preReduceLoss, dimensions);
                b.loss(sd.sum(outputName, m));
            } else {
                // Example: L1 loss (sum) + sum along examples -> sum along *all* dimensions
                b.loss(sd.sum(outputName, preReduceLoss));
            }
            break;
        case MEAN_BY_WEIGHT:
            SDVariable weightSum = sd.sum(weights);
            if (isMean) {
                // Example: MSE (mean) + mean by weights over examples
                // reduce along dims + reduce along remaining dims == reduce along *all* dims
                SDVariable m2 = sd.mean(preReduceLoss);
                b.loss(m2.div(outputName, weightSum));
            } else {
                // Example: L1 (sum) + mean by weights over examples
                SDVariable sum = sd.sum(preReduceLoss, dimensions);
                b.loss(sum.div(outputName, weightSum));
            }
            break;
        case MEAN_BY_COUNT:
            // Normalize by the number of non-zero weight entries (after broadcasting to labels)
            SDVariable nonZeroWeights = nonZeroCount(weights, label);
            SDVariable r;
            if (isMean) {
                // Example: MSE (mean) + mean by count over examples
                r = sd.sum(preReduceLoss);
            } else {
                // Example: L1 (sum) + mean by count over examples
                SDVariable sum = sd.sum(preReduceLoss, dimensions);
                r = sd.mean(sum);
            }
            b.loss(r.div(outputName, nonZeroWeights));
            break;
        default:
            throw new RuntimeException("Unknown reduction: " + reduction);
    }
    return b.build();
}
Use of org.nd4j.autodiff.samediff.SDVariable in the project nd4j by deeplearning4j.
In the class Diag, the method doDiff:
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // The gradient of diag is the diagonal part of the incoming gradient
    SDVariable incomingGrad = i_v.get(0);
    return Arrays.asList(sameDiff.diagPart(incomingGrad));
}
Use of org.nd4j.autodiff.samediff.SDVariable in the project nd4j by deeplearning4j.
In the class DynamicStitch, the method doDiff:
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
    // DynamicPartition and DynamicStitch are mutually inverse operations
    SDVariable grad = i_v.get(0);
    // For each index array, build a same-shaped array filled with its partition number
    SDVariable[] partitionIds = indices.clone();
    for (int p = 0; p < indices.length; p++) {
        partitionIds[p] = sameDiff.onesLike(indices[p]).mul(p);
    }
    // Stitch the partition ids together to recover the original partition assignment
    SDVariable partitionAssignment = sameDiff.dynamicStitch(partitionIds, indices);
    // Split the incoming gradient back into one gradient per partition
    SDVariable[] perPartitionGrads = sameDiff.dynamicPartition(grad, partitionAssignment, numPartitions);
    return new ArrayList<>(Arrays.asList(perPartitionGrads));
}
Use of org.nd4j.autodiff.samediff.SDVariable in the project nd4j by deeplearning4j.
In the class GradCheckTransforms, the method testDynamicStitch:
@Test
public void testDynamicStitch() {
    SameDiff sd = SameDiff.create();
    // Two data arrays and the index arrays saying where each value lands in the stitched output
    INDArray ia = Nd4j.create(new float[] { 5, 1, 3 }, new int[] { 1, 3 });
    INDArray ib = Nd4j.create(new float[] { 7, 2, 4 }, new int[] { 1, 3 });
    INDArray indexA = Nd4j.create(new float[] { 0, 1, 4 }, new int[] { 1, 3 });
    INDArray indexB = Nd4j.create(new float[] { 2, 3, 5 }, new int[] { 1, 3 });
    // Compute the expected output by executing the raw custom op directly
    INDArray expOut = Nd4j.create(new int[] { 1, 6 });
    DynamicCustomOp dynamicStitch = DynamicCustomOp.builder("dynamic_stitch").addInputs(indexA, indexB, ia, ib).addOutputs(expOut).build();
    Nd4j.getExecutioner().exec(dynamicStitch);
    SDVariable in1 = sd.var("in1", new int[] { 1, 3 });
    SDVariable in2 = sd.var("in2", new int[] { 1, 3 });
    SDVariable index1 = sd.var("index1", new int[] { 1, 3 });
    SDVariable index2 = sd.var("index2", new int[] { 1, 3 });
    sd.associateArrayWithVariable(ia, in1);
    sd.associateArrayWithVariable(ib, in2);
    sd.associateArrayWithVariable(indexA, index1);
    sd.associateArrayWithVariable(indexB, index2);
    SDVariable t = sd.dynamicStitch(new SDVariable[] { index1, index2 }, new SDVariable[] { in1, in2 });
    SDVariable loss = sd.mean("loss", t);
    sd.exec();
    INDArray out = t.getArr();
    // FIX: the original merely logged a mismatch, so this test could never fail.
    // Assert so the test actually reports a forward-pass failure.
    org.junit.Assert.assertEquals("forward failed", expOut, out);
}
Aggregations