Use of org.nd4j.linalg.api.ops.impl.transforms.RectifedLinear in project nd4j by deeplearning4j.
From the class DefaultOpFactory, method createTransform:
/**
 * Creates a TransformOp for the given op name.
 *
 * @param name      the name of the transform to create (e.g. "relu", "step", "pow")
 * @param x         the input array
 * @param y         the second input array, or null for unary transforms
 * @param z         the output array the result is written to
 * @param extraArgs optional extra arguments (e.g. the cutoff for "relu" and "step", the exponent for "pow")
 * @return the constructed TransformOp
 */
@Override
public TransformOp createTransform(String name, INDArray x, INDArray y, INDArray z, Object[] extraArgs) {
    TransformOp op = null;

    // Ops that need non-standard constructor arguments are handled explicitly; anything else
    // falls through to the default branch, which instantiates the op reflectively by name.
    switch (name) {
        case "_softmaxderivative":
            op = new org.nd4j.linalg.api.ops.impl.transforms.SoftMaxDerivative(x, z);
            break;
        case "set":
            op = new org.nd4j.linalg.api.ops.impl.transforms.Set(x, y, z, z.length());
            break;
        case "relu":
            // extraArgs[0], if present, is the RectifedLinear cutoff (defaults to 0.0)
            op = new RectifedLinear(x, z, x.length(), extraArgs == null || extraArgs[0] == null ? 0.0 : (double) extraArgs[0]);
            break;
        case "step":
            op = new Step(x, y, z, x.length(), extraArgs == null || extraArgs[0] == null ? 0.0 : (double) extraArgs[0]);
            break;
        case "pow":
            op = new Pow(x, z, (double) extraArgs[0]);
            break;
        default:
            // Resolve the op class by name and construct it via its (x, z) or (x, y, z) constructor.
            try {
                if (y == null)
                    op = (TransformOp) DifferentialFunctionClassHolder.getInstance().getInstance(name).getClass().getConstructor(INDArray.class, INDArray.class).newInstance(x, z);
                else
                    op = (TransformOp) DifferentialFunctionClassHolder.getInstance().getInstance(name).getClass().getConstructor(INDArray.class, INDArray.class, INDArray.class).newInstance(x, y, z);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
    }
    /*
    switch (opName) {
        case "set":
            op = new org.nd4j.linalg.api.ops.impl.transforms.Set(x,y,z,z.length());
            break;
        case "relu":
            op = new RectifedLinear(x, z, x.length(),extraArgs == null || extraArgs[0] == null ? 0.0 : (double) extraArgs[0]);
            break;
        case "step":
            op = new Step(x,y,z,x.length(),extraArgs == null || extraArgs[0] == null ? 0.0 : (double) extraArgs[0]);
            break;
        case "abs":
            op = new Abs(x, z);
            break;
        case "acos":
            op = new ACos(x, z);
            break;
        case "asin":
            op = new ASin(x, z);
            break;
        case "atan":
            op = new ATan(x, z);
            break;
        case "ceil":
            op = new Ceil(x, z);
            break;
        case "cos":
            op = new Cos(x, z);
            break;
        case "exp":
            op = new Exp(x, z);
            break;
        case "elu":
            op = new ELU(x, z);
            break;
        case "floor":
            op = new Floor(x, z);
            break;
        case "hardtanh":
            op = new HardTanh(x, z);
            break;
        case "hardsigmoid":
            op = new HardSigmoid(x, z);
            break;
        case "identity":
            op = new Identity(x, z);
            break;
        case "leakyrelu":
            op = new LeakyReLU(x, z);
            break;
        case "log":
            op = new Log(x, z);
            break;
        case "logsoftmax":
            op = new LogSoftMax(x, z);
            break;
        case "maxout":
            op = new MaxOut(x, z);
            break;
        case "negative":
            op = new Negative(x, z);
            break;
        case "pow":
            op = new Pow(x, z, (double) extraArgs[0]);
            break;
        case "round":
            op = new Round(x, z);
            break;
        case "sigmoid":
            op = new Sigmoid(x, z);
            break;
        case "sign":
            op = new Sign(x, z);
            break;
        case "sin":
            op = new Sin(x, z);
            break;
        case "softsign":
            op = new SoftSign(x, z);
            break;
        case "sqrt":
            op = new Sqrt(x, z);
            break;
        case "stabilize":
            op = new Stabilize(x, z, 1);
            break;
        case "tanh":
            op = new Tanh(x, z);
            break;
        case "rationaltanh":
            op = new RationalTanh(x, z);
            break;
        case "timesoneminus":
            op = new TimesOneMinus(x, z);
            break;
        case "softmaxderivative":
            op = new org.nd4j.linalg.api.ops.impl.transforms.SoftMaxDerivative(x, z);
            break;
        case "softmax":
            op = new SoftMax(x, z);
            break;
        case "softplus":
            op = new SoftPlus(x, z);
            break;
        case "cube":
            op = new Cube(x, z);
            break;
        case "sigmoidderivative":
            op = new SigmoidDerivative(x,z);
            break;
        case "hard_sigmoidderivative":
            op = new HardSigmoidDerivative(x,z);
            break;
        case "hardtanhderivative":
            op = new HardTanhDerivative(x,z);
            break;
        case "tanhderivative":
            op = new TanhDerivative(x,z);
            break;
        case "leakyreluderivative":
            op = new LeakyReLUDerivative(x,z);
            break;
        case "mul":
            op = new MulOp(x,y,z);
            break;
        case "add":
            op = new AddOp(x,y,z);
            break;
        case "sub":
            op = new SubOp(x,y,z);
            break;
        case "div":
            op = new DivOp(x,y,z);
            break;
        case "rdiv":
            op = new RDivOp(x,y,z);
            break;
        case "rsub":
            op = new RSubOp(x,y,z);
            break;
        case "neg":
            op = new Negative(x,z);
            break;
        default:
            throw new ND4JIllegalStateException("No op found " + opName);
    }
    */
    op.setExtraArgs(extraArgs);
    return op;
}
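
For context, a minimal sketch of how this factory method might be invoked to build and run the "relu" transform, which resolves to RectifedLinear above. The surrounding wiring (Nd4j.getOpFactory(), Nd4j.getExecutioner().execAndReturn(...)) is assumed from the nd4j API of this period rather than taken from this page, so treat it as illustrative only:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.factory.Nd4j;

public class ReluFactoryExample {
    public static void main(String[] args) {
        INDArray x = Nd4j.linspace(-2, 2, 10);     // input
        INDArray z = Nd4j.zeros(x.shape());        // output buffer
        // "relu" resolves to new RectifedLinear(x, z, x.length(), cutoff), where cutoff = extraArgs[0]
        TransformOp relu = Nd4j.getOpFactory().createTransform("relu", x, null, z, new Object[] {0.0});
        Nd4j.getExecutioner().execAndReturn(relu);
        System.out.println(z);                     // entries below the 0.0 cutoff are clipped
    }
}

If a given nd4j version wires the factory differently, a DefaultOpFactory could be instantiated directly instead of going through Nd4j.getOpFactory().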
Use of org.nd4j.linalg.api.ops.impl.transforms.RectifedLinear in project nd4j by deeplearning4j.
From the class ActivationRReLU, method getActivation:
@Override
public INDArray getActivation(INDArray in, boolean training) {
    if (training) {
        // Training: sample alpha uniformly from [l, u] for every element, then replace the
        // negative entries of `in` with alpha * in (in place).
        this.alpha = Nd4j.rand(in.shape(), l, u, Nd4j.getRandom());
        INDArray inTimesAlpha = in.mul(alpha);
        BooleanIndexing.replaceWhere(in, inTimesAlpha, Conditions.lessThan(0));
    } else {
        // Inference: no sampling; apply a fixed RectifedLinear op parameterized by the
        // midpoint a = 0.5 * (l + u).
        this.alpha = null;
        double a = 0.5 * (l + u);
        return Nd4j.getExecutioner().execAndReturn(new RectifedLinear(in, a));
    }
    return in;
}
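
A usage sketch of the activation itself. The ActivationRReLU(l, u) constructor and the bound values used here are assumptions for illustration; only the getActivation method above comes from this page. Note that training mode modifies its input in place, hence the dup() calls:

import org.nd4j.linalg.activations.impl.ActivationRReLU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RReLUExample {
    public static void main(String[] args) {
        ActivationRReLU rrelu = new ActivationRReLU(1.0 / 8.0, 1.0 / 3.0); // assumed lower/upper bounds l and u
        INDArray in = Nd4j.linspace(-3, 3, 7);

        // Training: negative values are scaled by a freshly sampled alpha in [l, u].
        INDArray trainOut = rrelu.getActivation(in.dup(), true);

        // Inference: the fixed RectifedLinear op with a = 0.5 * (l + u) is applied instead.
        INDArray testOut = rrelu.getActivation(in.dup(), false);

        System.out.println(trainOut);
        System.out.println(testOut);
    }
}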