Use of org.nd4j.linalg.api.ops.impl.transforms.SoftMaxDerivative in project nd4j by deeplearning4j.
From the class DerivativeTests, method softmaxsimpleLossTest.
@Test
public void softmaxsimpleLossTest() {
/*
The softmax derivative is correct on its own, but when it is applied in the chain rule
the current derivative function is incomplete.
For this test, assume the function of interest is just MSE.
The fix: the derivative of a softmax needs to return a rank 2 matrix (the full Jacobian).
Right now we only get the diagonal elements of that matrix.
http://stats.stackexchange.com/questions/79454/softmax-layer-in-a-neural-network
*/
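// For reference, the full softmax Jacobian for y = softmax(x) is
//     dy_i/dx_j = y_i * (delta_ij - y_j),
// i.e. y_i * (1 - y_i) on the diagonal and -y_i * y_j off the diagonal.
// As noted above, the current SoftMaxDerivative op only returns the diagonal.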
// random array representing the preout
INDArray X = Nd4j.rand(1, 2);
// preout transformed to y_hat with softmax
INDArray YHat = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", X.dup()));
// hard-coded target to construct a loss with, using MSE
INDArray Y = Nd4j.create(new double[][] { { 0.123, 1 - 0.123 } });
// This is the MSE now
double lossHere = Transforms.pow(Y.sub(YHat), 2).sumNumber().doubleValue();
INDArray softmaxDer = Nd4j.getExecutioner().execAndReturn(new SoftMaxDerivative(X.dup()));
// the chain rule as currently applied: elementwise product of dLdY = -2*(y - yhat) with the softmax derivative
INDArray dLdY = Y.sub(YHat).mul(-2);
INDArray currentGradient = dLdY.mul(softmaxDer);
// what I think we should be doing
// we have x0, x1 -> y0,y1
// we need the derivatives of the output of the softmax wrt every input (x0,x1)
// we only have dy0/dx0 and dy1/dx1
// we also need dy0/dx1 and dy1/dx0
// below is the chain rule applied when L is a function of y0,y1, and y0 and y1 are in turn functions of BOTH x0 and x1
// dL/dx0 = (dL/dy0) * (dy0/dx0) + (dL/dy1) * (dy1/dx0)
// dL/dx1 = (dL/dy0) * (dy0/dx1) + (dL/dy1) * (dy1/dx1)
// worked out on paper and checked against the link above
// dy0/dx0 = y0*(1-y0) = y0*y1
// dy1/dx0 = -y1*(1-y1) = -y0*y1
// dy0/dx1 = -y0*(1-y0) = -y0*y1
// dy1/dx1 = y1*(1-y1) = y0*y1
// [dL/dx0, dL/dx1] = [dL/dy0, dL/dy1] * [[dy0/dx0, dy1/dx0], [dy0/dx1, dy1/dx1]]  (the softmax Jacobian is symmetric, so the transpose does not matter)
double y0y1 = softmaxDer.getDouble(0, 0);
// hack, but this is what we need to implement; straightforward here but more involved for length > 2
// INDArray mysoftmaxDer = Nd4j.create(new double[][] {{y0y1,y0y1*-1},{-1*y0y1,y0y1}});
INDArray mysoftmaxDer = correctSoftmax(X);
INDArray myGradient = mysoftmaxDer.mulRowVector(dLdY).sum(1);
double epsilon = 0.0001;
INDArray Xiplus, Ximinus;
INDArray YHatplus, YHatminus;
double lossplus, lossminus;
INDArray numGradient = Nd4j.zeros(1, 2);
for (int i = 0; i < 2; i++) {
/* change X one value at a time */
// +epsilon
double x = X.getDouble(0, i);
Xiplus = X.dup();
Xiplus.put(0, i, x + epsilon);
YHatplus = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", Xiplus.dup()));
lossplus = Transforms.pow(Y.sub(YHatplus), 2).sumNumber().doubleValue();
// -epsilon
Ximinus = X.dup();
Ximinus.put(0, i, x - epsilon);
YHatminus = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", Ximinus.dup()));
lossminus = Transforms.pow(Y.sub(YHatminus), 2).sumNumber().doubleValue();
double gradienti = (lossplus - lossminus) / (2 * epsilon);
numGradient.put(0, i, gradienti);
}
System.out.println("=========================");
System.out.println("NUMERICAL:");
System.out.println(numGradient);
System.out.println("\nCURRENTLY:");
System.out.println(currentGradient);
System.out.println("\nMY GRADIENT:");
System.out.println(myGradient + "\n");
System.out.println("Because of the nature of the derivative of the softmax for length = 2, our current method will make it off by a factor of 2");
System.out.println("=========================");
}
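Both tests depend on a correctSoftmax helper that is not shown on this page. As a minimal sketch, assuming that helper builds the full softmax Jacobian J[i][j] = y_i * (delta_ij - y_j) for a 1 x n row vector, it could look like the following (correctSoftmaxSketch is an illustration, not the actual implementation from DerivativeTests):
public static INDArray correctSoftmaxSketch(INDArray X) {
    // hypothetical sketch of the full softmax Jacobian, not the helper used by DerivativeTests
    // softmax output y as a 1 x n row vector
    INDArray Y = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", X.dup()));
    // outer product y^T * y, negated, gives the off-diagonal terms -y_i * y_j
    INDArray J = Y.transpose().mmul(Y).mul(-1);
    // adding y_i on the diagonal turns the diagonal into y_i * (1 - y_i)
    J.addi(Nd4j.diag(Y));
    return J;
}
With a full Jacobian like this, the expression mysoftmaxDer.mulRowVector(dLdY).sum(1) used in the tests computes each dL/dx_i as the sum over j of J[i][j] * dL/dy_j.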
Use of org.nd4j.linalg.api.ops.impl.transforms.SoftMaxDerivative in project nd4j by deeplearning4j.
From the class DerivativeTests, method softmaxsimplelongerlengthLossTest.
@Test
public void softmaxsimplelongerlengthLossTest() {
/*
Read comments in earlier test for length = 2
*/
// random array representing the preout
int someLength = 7;
INDArray X = Nd4j.rand(1, someLength);
// preout transformed to y_hat with softmax
INDArray YHat = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", X.dup()));
// random target (a softmax of random values) to construct a loss with, using MSE
INDArray temp = Nd4j.rand(1, someLength);
INDArray Y = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", temp));
// This is the MSE now
double lossHere = Transforms.pow(Y.sub(YHat), 2).sumNumber().doubleValue();
INDArray softmaxDer = Nd4j.getExecutioner().execAndReturn(new SoftMaxDerivative(X.dup()));
// the chain rule as currently applied: elementwise product of dLdY = -2*(y - yhat) with the softmax derivative
INDArray dLdY = Y.sub(YHat).mul(-2);
INDArray currentGradient = dLdY.mul(softmaxDer);
INDArray mysoftmaxDer = correctSoftmax(X);
INDArray myGradient = mysoftmaxDer.mulRowVector(dLdY).sum(1);
double epsilon = 0.0001;
INDArray Xiplus, Ximinus;
INDArray YHatplus, YHatminus;
double lossplus, lossminus;
INDArray numGradient = Nd4j.zeros(1, someLength);
for (int i = 0; i < someLength; i++) {
/* change X one value at a time */
// +epsilon
double x = X.getDouble(0, i);
Xiplus = X.dup();
Xiplus.put(0, i, x + epsilon);
YHatplus = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", Xiplus.dup()));
lossplus = Transforms.pow(Y.sub(YHatplus), 2).sumNumber().doubleValue();
// -epsilon
Ximinus = X.dup();
Ximinus.put(0, i, x - epsilon);
YHatminus = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", Ximinus.dup()));
lossminus = Transforms.pow(Y.sub(YHatminus), 2).sumNumber().doubleValue();
double gradienti = (lossplus - lossminus) / (2 * epsilon);
numGradient.put(0, i, gradienti);
}
System.out.println("=========================");
System.out.println("NUMERICAL GRADIENT:");
System.out.println(new NDArrayStrings(6).format(numGradient).toString());
System.out.println("\nANALYTIC USING EXISTING SOFTMAX DER:");
System.out.println(new NDArrayStrings(6).format(currentGradient).toString());
System.out.println("\nGRADIENT USING MY VERSION OF SOFTMAX DER:");
System.out.println(new NDArrayStrings(6).format(myGradient).toString());
System.out.println("=========================");
}
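Both tests repeat the same central-difference loop to build the numerical gradient. A generic version of that check might look like the sketch below (numericalGradientSketch is a hypothetical helper, not part of DerivativeTests), assuming X is a 1 x n row vector and the loss is sum((Y - softmax(X))^2):
public static INDArray numericalGradientSketch(INDArray X, INDArray Y, double epsilon) {
    // hypothetical helper: central-difference gradient of sum((Y - softmax(X))^2) with respect to X
    INDArray grad = Nd4j.zeros(1, (int) X.length());
    for (int i = 0; i < X.length(); i++) {
        double x = X.getDouble(0, i);
        // perturb x_i by +epsilon and evaluate the loss
        INDArray xPlus = X.dup();
        xPlus.put(0, i, x + epsilon);
        INDArray yHatPlus = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", xPlus));
        double lossPlus = Transforms.pow(Y.sub(yHatPlus), 2).sumNumber().doubleValue();
        // perturb x_i by -epsilon and evaluate the loss
        INDArray xMinus = X.dup();
        xMinus.put(0, i, x - epsilon);
        INDArray yHatMinus = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", xMinus));
        double lossMinus = Transforms.pow(Y.sub(yHatMinus), 2).sumNumber().doubleValue();
        // (L(x + eps) - L(x - eps)) / (2 * eps) approximates dL/dx_i
        grad.put(0, i, (lossPlus - lossMinus) / (2 * epsilon));
    }
    return grad;
}
The error of a central difference shrinks on the order of epsilon^2, so with epsilon = 0.0001 the numerical gradient should agree with a correct analytic gradient to several decimal places.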