Use of org.nd4j.linalg.activations.impl.ActivationSoftmax in project deeplearning4j by deeplearning4j.
Class RnnOutputLayer, method output:
@Override
public INDArray output(boolean training) {
    //Assume that input is 3d
    if (input.rank() != 3)
        throw new IllegalArgumentException("input must be rank 3");
    INDArray preOutput2d = preOutput2d(training);
    //if(conf.getLayer().getActivationFunction().equals("softmax")) {
    if (conf.getLayer().getActivationFn() instanceof ActivationSoftmax) {
        INDArray out2d = Nd4j.getExecutioner().execAndReturn(new SoftMax(preOutput2d));
        if (maskArray != null) {
            out2d.muliColumnVector(maskArray);
        }
        return TimeSeriesUtils.reshape2dTo3d(out2d, input.size(0));
    }
    if (training)
        applyDropOutIfNecessary(training);
    INDArray origInput = input;
    this.input = TimeSeriesUtils.reshape3dTo2d(input);
    INDArray out = super.activate(true);
    this.input = origInput;
    if (maskArray != null) {
        out.muliColumnVector(maskArray);
    }
    return TimeSeriesUtils.reshape2dTo3d(out, input.size(0));
}
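A standalone sketch (not from the project) of the round trip the softmax branch performs: flatten the 3d activations to 2d, apply softmax per row, then restore the [miniBatchSize, nOut, timeSteps] layout. It reuses ActivationSoftmax and the TimeSeriesUtils helpers from the snippet; the class name RnnSoftmaxSketch, the toy shapes, and the assumed import path for TimeSeriesUtils (org.deeplearning4j.util) are illustrative only.

import java.util.Arrays;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RnnSoftmaxSketch {
    public static void main(String[] args) {
        // Toy RNN activations: [miniBatchSize=2, nOut=3, timeSteps=4]
        INDArray preOutput3d = Nd4j.rand(new int[] {2, 3, 4});

        // Flatten time so softmax is applied independently per (example, timestep) row
        INDArray preOutput2d = TimeSeriesUtils.reshape3dTo2d(preOutput3d);

        // Same activation type the layer tests for with instanceof above
        INDArray out2d = new ActivationSoftmax().getActivation(preOutput2d.dup(), false);

        // Restore the [miniBatchSize, nOut, timeSteps] layout, as output(boolean) does
        INDArray out3d = TimeSeriesUtils.reshape2dTo3d(out2d, 2);
        System.out.println(Arrays.toString(out3d.shape()));    // [2, 3, 4]
        System.out.println(out2d.sum(1));                      // each row sums to 1.0
    }
}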
Use of org.nd4j.linalg.activations.impl.ActivationSoftmax in project nd4j by deeplearning4j.
Class LossBinaryXENT, method scoreArray:
private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
    if (labels.size(1) != preOutput.size(1)) {
        throw new IllegalArgumentException("Labels array numColumns (size(1) = " + labels.size(1)
                        + ") does not match output layer number of outputs (nOut = " + preOutput.size(1) + ") ");
    }
    INDArray scoreArr;
    if (activationFn instanceof ActivationSoftmax) {
        // Use LogSoftMax op to avoid numerical issues when calculating score
        INDArray logsoftmax = Nd4j.getExecutioner().execAndReturn(new LogSoftMax(preOutput.dup()));
        scoreArr = logsoftmax.muli(labels);
    } else {
        // INDArray output = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(activationFn, preOutput.dup()));
        INDArray output = activationFn.getActivation(preOutput.dup(), true);
        if (clipEps > 0.0) {
            CustomOp op = DynamicCustomOp.builder("clipbyvalue")
                            .addInputs(output)
                            .callInplace(true)
                            .addFloatingPointArguments(clipEps, 1.0 - clipEps)
                            .build();
            Nd4j.getExecutioner().exec(op);
        }
        scoreArr = Transforms.log(output, true).muli(labels);
        INDArray secondTerm = output.rsubi(1);
        Transforms.log(secondTerm, false);
        secondTerm.muli(labels.rsub(1));
        scoreArr.addi(secondTerm);
    }
    // Weighted loss function
    if (weights != null) {
        if (weights.length() != preOutput.size(1)) {
            throw new IllegalStateException("Weights vector (length " + weights.length()
                            + ") does not match output.size(1)=" + preOutput.size(1));
        }
        scoreArr.muliRowVector(weights);
    }
    if (mask != null) {
        LossUtil.applyMask(scoreArr, mask);
    }
    return scoreArr;
}
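The non-softmax branch computes the per-element binary cross-entropy score labels*log(p) + (1-labels)*log(1-p). A hedged sketch comparing that manual computation with LossBinaryXENT.computeScore for a sigmoid activation; the class name BinaryXentSketch and the toy values are illustrative, and the library score additionally negates, clips, and averages its per-example sums, so the two printed numbers are related but not identical.

import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.ops.transforms.Transforms;

public class BinaryXentSketch {
    public static void main(String[] args) {
        INDArray labels = Nd4j.create(new double[] {1, 0, 0, 1}, new int[] {2, 2});
        INDArray preOutput = Nd4j.create(new double[] {2.0, -1.0, 0.5, 1.5}, new int[] {2, 2});
        IActivation sigmoid = new ActivationSigmoid();

        // Manual per-element score, mirroring the else-branch of scoreArray above:
        // labels * log(p) + (1 - labels) * log(1 - p)
        INDArray p = sigmoid.getActivation(preOutput.dup(), true);
        INDArray manual = Transforms.log(p, true).muli(labels)
                        .addi(Transforms.log(p.rsub(1), false).muli(labels.rsub(1)));

        // Library score for the same inputs (no mask, averaged over examples)
        double score = new LossBinaryXENT().computeScore(labels, preOutput, sigmoid, null, true);

        System.out.println("manual sum = " + manual.sumNumber() + ", library score = " + score);
    }
}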
Use of org.nd4j.linalg.activations.impl.ActivationSoftmax in project nd4j by deeplearning4j.
Class LossMCXENT, method scoreArray:
private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
    if (labels.size(1) != preOutput.size(1)) {
        throw new IllegalArgumentException("Labels array numColumns (size(1) = " + labels.size(1)
                        + ") does not match output layer number of outputs (nOut = " + preOutput.size(1) + ") ");
    }
    INDArray output = activationFn.getActivation(preOutput.dup(), true);
    if (activationFn instanceof ActivationSoftmax && softmaxClipEps > 0.0) {
        BooleanIndexing.replaceWhere(output, softmaxClipEps, Conditions.lessThan(softmaxClipEps));
        BooleanIndexing.replaceWhere(output, 1.0 - softmaxClipEps, Conditions.greaterThan(1.0 - softmaxClipEps));
    }
    INDArray scoreArr = Transforms.log(output, false).muli(labels);
    // Weighted loss function
    if (weights != null) {
        if (weights.length() != scoreArr.size(1)) {
            throw new IllegalStateException("Weights vector (length " + weights.length()
                            + ") does not match output.size(1)=" + preOutput.size(1));
        }
        scoreArr.muliRowVector(weights);
    }
    if (mask != null) {
        LossUtil.applyMask(scoreArr, mask);
    }
    return scoreArr;
}
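With a softmax activation the score reduces to log(softmax(preOutput)) * labels, i.e. the log-probability of the true class for one-hot labels. A small sketch of that computation; the class name McxentScoreSketch and the toy values are illustrative only.

import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class McxentScoreSketch {
    public static void main(String[] args) {
        // One-hot labels for two examples, three classes
        INDArray labels = Nd4j.create(new double[] {1, 0, 0, 0, 0, 1}, new int[] {2, 3});
        INDArray preOutput = Nd4j.create(new double[] {2.0, 0.5, -1.0, 0.1, 0.2, 1.5}, new int[] {2, 3});
        IActivation softmax = new ActivationSoftmax();

        // Mirrors scoreArray above: scoreArr = log(softmax(preOutput)) * labels
        INDArray output = softmax.getActivation(preOutput.dup(), true);
        INDArray scoreArr = Transforms.log(output, false).muli(labels);

        // Per-example negative log-likelihood of the true class
        System.out.println(scoreArr.sum(1).neg());
    }
}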
Use of org.nd4j.linalg.activations.impl.ActivationSoftmax in project nd4j by deeplearning4j.
Class LossMCXENT, method computeGradient:
@Override
public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
    if (labels.size(1) != preOutput.size(1)) {
        throw new IllegalArgumentException("Labels array numColumns (size(1) = " + labels.size(1)
                        + ") does not match output layer number of outputs (nOut = " + preOutput.size(1) + ") ");
    }
    INDArray grad;
    // INDArray output = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(activationFn, preOutput.dup()));
    INDArray output = activationFn.getActivation(preOutput.dup(), true);
    if (activationFn instanceof ActivationSoftmax) {
        if (mask != null && LossUtil.isPerOutputMasking(output, mask)) {
            throw new UnsupportedOperationException("Per output masking for MCXENT + softmax: not supported");
        }
        // Weighted loss function
        if (weights != null) {
            if (weights.length() != output.size(1)) {
                throw new IllegalStateException("Weights vector (length " + weights.length()
                                + ") does not match output.size(1)=" + output.size(1));
            }
            INDArray temp = labels.mulRowVector(weights);
            INDArray col = temp.sum(1);
            grad = output.mulColumnVector(col).sub(temp);
        } else {
            grad = output.subi(labels);
        }
    } else {
        INDArray dLda = output.rdivi(labels).negi();
        // TODO activation function with weights
        grad = activationFn.backprop(preOutput, dLda).getFirst();
        // Weighted loss function
        if (weights != null) {
            if (weights.length() != output.size(1)) {
                throw new IllegalStateException("Weights vector (length " + weights.length()
                                + ") does not match output.size(1)=" + output.size(1));
            }
            grad.muliRowVector(weights);
        }
    }
    // Loss function with masking
    if (mask != null) {
        LossUtil.applyMask(grad, mask);
    }
    return grad;
}
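For softmax without weights, the branch above uses the standard softmax/cross-entropy shortcut dL/dz = softmax(z) - labels instead of backpropagating through the activation. A hedged check that computeGradient matches that shortcut on toy values; the class name McxentGradientSketch and the values are illustrative, and it assumes the default softmaxClipEps clipping does not alter these particular outputs.

import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;

public class McxentGradientSketch {
    public static void main(String[] args) {
        INDArray labels = Nd4j.create(new double[] {1, 0, 0, 0, 0, 1}, new int[] {2, 3});
        INDArray preOutput = Nd4j.create(new double[] {2.0, 0.5, -1.0, 0.1, 0.2, 1.5}, new int[] {2, 3});
        IActivation softmax = new ActivationSoftmax();

        // Gradient from the loss function (no weights, no mask)
        INDArray grad = new LossMCXENT().computeGradient(labels, preOutput, softmax, null);

        // Softmax shortcut from the if-branch above: dL/dz = softmax(z) - labels
        INDArray expected = softmax.getActivation(preOutput.dup(), true).subi(labels);

        System.out.println(grad.equalsWithEps(expected, 1e-6));    // expected: true
    }
}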