Search in sources:

Example 1 with RectifedLinear

Use of org.nd4j.linalg.api.ops.impl.transforms.RectifedLinear in project nd4j by deeplearning4j.

Source: class DefaultOpFactory, method createTransform.

/**
 * Creates the transform op registered under {@code name}.
 *
 * <p>A few ops that take a scalar argument or a non-standard constructor are special-cased.
 * Every other name is resolved through {@code DifferentialFunctionClassHolder}. The op is then
 * instantiated reflectively using its {@code (INDArray, INDArray)} constructor with
 * {@code (x, z)}. When {@code y} is non-null, its {@code (INDArray, INDArray, INDArray)}
 * constructor is used with {@code (x, y, z)} instead.
 *
 * @param name      op name, e.g. {@code "relu"}, {@code "step"}, {@code "pow"}, or any op
 *                  name known to {@code DifferentialFunctionClassHolder}
 * @param x         input array
 * @param y         optional second input; selects the three-array constructor when non-null
 * @param z         result array
 * @param extraArgs op-specific arguments, also attached to the created op via
 *                  {@code setExtraArgs}. For {@code "relu"} and {@code "step"}, the first
 *                  element (any {@link Number}) is passed as the op's scalar constructor
 *                  argument. It defaults to {@code 0.0} when the array is null or empty, or
 *                  when its first element is null. {@code "pow"} requires a numeric first
 *                  element (presumably the exponent).
 * @return the newly created op
 * @throws IllegalArgumentException if {@code "pow"} has no numeric first argument, if
 *                                  {@code name} resolves to no registered op, or if the
 *                                  registered op is not a {@code TransformOp}
 * @throws RuntimeException         if reflective construction of the registered op fails
 */
@Override
public TransformOp createTransform(String name, INDArray x, INDArray y, INDArray z, Object[] extraArgs) {
    TransformOp op;
    switch (name) {
        case "_softmaxderivative":
            op = new org.nd4j.linalg.api.ops.impl.transforms.SoftMaxDerivative(x, z);
            break;
        case "set":
            op = new org.nd4j.linalg.api.ops.impl.transforms.Set(x, y, z, z.length());
            break;
        case "relu":
            op = new RectifedLinear(x, z, x.length(), firstArgOrDefault(extraArgs, 0.0));
            break;
        case "step":
            op = new Step(x, y, z, x.length(), firstArgOrDefault(extraArgs, 0.0));
            break;
        case "pow":
            op = new Pow(x, z, requireFirstArg(extraArgs, name));
            break;
        default:
            op = instantiateRegisteredTransform(name, x, y, z);
            break;
    }
    op.setExtraArgs(extraArgs);
    return op;
}

/**
 * Returns {@code extraArgs[0]} as a double. Falls back to {@code defaultValue} when the array
 * is null or empty, or when its first element is null.
 */
private static double firstArgOrDefault(Object[] extraArgs, double defaultValue) {
    if (extraArgs == null || extraArgs.length == 0 || extraArgs[0] == null) {
        return defaultValue;
    }
    // Accept any boxed numeric type, not only Double.
    return ((Number) extraArgs[0]).doubleValue();
}

/** Returns {@code extraArgs[0]} as a double; throws a descriptive error when it is missing. */
private static double requireFirstArg(Object[] extraArgs, String opName) {
    if (extraArgs == null || extraArgs.length == 0 || extraArgs[0] == null) {
        throw new IllegalArgumentException(
                "Op '" + opName + "' requires a numeric value as its first extra argument");
    }
    return ((Number) extraArgs[0]).doubleValue();
}

/**
 * Looks up {@code name} in {@code DifferentialFunctionClassHolder} and reflectively constructs
 * a new instance of the same class from {@code (x, z)}, or from {@code (x, y, z)} when
 * {@code y} is non-null.
 */
private static TransformOp instantiateRegisteredTransform(String name, INDArray x, INDArray y, INDArray z) {
    Object prototype = DifferentialFunctionClassHolder.getInstance().getInstance(name);
    if (prototype == null) {
        throw new IllegalArgumentException("No op registered under name '" + name + "'");
    }
    Object created;
    try {
        if (y == null) {
            created = prototype.getClass()
                    .getConstructor(INDArray.class, INDArray.class)
                    .newInstance(x, z);
        } else {
            created = prototype.getClass()
                    .getConstructor(INDArray.class, INDArray.class, INDArray.class)
                    .newInstance(x, y, z);
        }
    } catch (ReflectiveOperationException e) {
        throw new RuntimeException("Could not instantiate transform op '" + name + "'", e);
    }
    if (!(created instanceof TransformOp)) {
        throw new IllegalArgumentException("Op '" + name + "' (" + prototype.getClass().getName()
                + ") is not a TransformOp");
    }
    return (TransformOp) created;
}
Also used: RectifedLinear (org.nd4j.linalg.api.ops.impl.transforms.RectifedLinear), org.nd4j.linalg.api.ops (org.nd4j.linalg.api.ops), Pow (org.nd4j.linalg.api.ops.impl.transforms.Pow), Step (org.nd4j.linalg.api.ops.impl.transforms.Step)

Example 2 with RectifedLinear

Use of org.nd4j.linalg.api.ops.impl.transforms.RectifedLinear in project nd4j by deeplearning4j.

Source: class ActivationRReLU, method getActivation.

/**
 * Applies the randomized leaky ReLU (RReLU) activation to {@code in}, overwriting it.
 *
 * <p>During training, a fresh slope array shaped like {@code in} is sampled with
 * {@code Nd4j.rand} between {@code l} and {@code u} and stored in {@code alpha}. Every negative
 * element of {@code in} is then replaced by itself times its slope. At inference time,
 * {@code alpha} is cleared and every negative element is multiplied by the fixed mean slope
 * {@code 0.5 * (l + u)}. Non-negative elements are left unchanged in both modes.
 *
 * @param in       pre-activation values; modified in place
 * @param training {@code true} to sample random slopes, {@code false} to use the mean slope
 * @return {@code in}, after modification
 */
@Override
public INDArray getActivation(INDArray in, boolean training) {
    if (training) {
        this.alpha = Nd4j.rand(in.shape(), l, u, Nd4j.getRandom());
        INDArray inTimesAlpha = in.mul(alpha);
        BooleanIndexing.replaceWhere(in, inTimesAlpha, Conditions.lessThan(0));
    } else {
        this.alpha = null;
        double meanSlope = 0.5 * (l + u);
        // Mirror the training branch with the expected slope. The previous
        // RectifedLinear(in, meanSlope) call passed the value as that op's scalar
        // (presumably a cutoff/threshold), not as a slope for negative inputs.
        INDArray inTimesSlope = in.mul(meanSlope);
        BooleanIndexing.replaceWhere(in, inTimesSlope, Conditions.lessThan(0));
    }
    return in;
}
Also used: RectifedLinear (org.nd4j.linalg.api.ops.impl.transforms.RectifedLinear), INDArray (org.nd4j.linalg.api.ndarray.INDArray)

Aggregations

RectifedLinear (org.nd4j.linalg.api.ops.impl.transforms.RectifedLinear): 2; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1; org.nd4j.linalg.api.ops (org.nd4j.linalg.api.ops): 1; Pow (org.nd4j.linalg.api.ops.impl.transforms.Pow): 1; Step (org.nd4j.linalg.api.ops.impl.transforms.Step): 1