
Example 1 with BaseRandomOp

Use of org.nd4j.linalg.api.ops.random.BaseRandomOp in project nd4j by deeplearning4j.

In class OpsMappingTests, method getOperations:

protected List<Operation> getOperations(@NonNull Op.Type type) {
    val list = new ArrayList<Operation>();
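    // Scan compiled classes under org.nd4j for subtypes; the exclude
    // pattern drops every input whose name does not end in ".class".
    Reflections f = new Reflections(new ConfigurationBuilder()
            .filterInputsBy(new FilterBuilder()
                    .include(FilterBuilder.prefix("org.nd4j.*"))
                    .exclude("^(?!.*\\.class$).*$"))
            .setUrls(ClasspathHelper.forPackage("org.nd4j"))
            .setScanners(new SubTypesScanner()));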
    switch(type) {
        case SUMMARYSTATS:
            {
                Set<Class<? extends Variance>> clazzes = f.getSubTypesOf(Variance.class);
                for (Class<? extends DifferentialFunction> clazz : clazzes) addOperation(clazz, list);
            }
            break;
        case RANDOM:
            {
                Set<Class<? extends BaseRandomOp>> clazzes = f.getSubTypesOf(BaseRandomOp.class);
                for (Class<? extends DifferentialFunction> clazz : clazzes) addOperation(clazz, list);
            }
            break;
        case INDEXREDUCE:
            {
                Set<Class<? extends BaseIndexAccumulation>> clazzes = f.getSubTypesOf(BaseIndexAccumulation.class);
                for (Class<? extends DifferentialFunction> clazz : clazzes) addOperation(clazz, list);
            }
            break;
        case REDUCE3:
        case REDUCE:
            {
                Set<Class<? extends BaseAccumulation>> clazzes = f.getSubTypesOf(BaseAccumulation.class);
                for (Class<? extends DifferentialFunction> clazz : clazzes) addOperation(clazz, list);
            }
            break;
        case BROADCAST:
            {
                Set<Class<? extends BaseBroadcastOp>> clazzes = f.getSubTypesOf(BaseBroadcastOp.class);
                for (Class<? extends DifferentialFunction> clazz : clazzes) addOperation(clazz, list);
            }
            break;
        case SCALAR:
            {
                Set<Class<? extends BaseScalarOp>> clazzes = f.getSubTypesOf(BaseScalarOp.class);
                for (Class<? extends DifferentialFunction> clazz : clazzes) addOperation(clazz, list);
            }
            break;
        case PAIRWISE:
        case TRANSFORM:
            {
                Set<Class<? extends BaseTransformOp>> clazzes = f.getSubTypesOf(BaseTransformOp.class);
                for (Class<? extends DifferentialFunction> clazz : clazzes) addOperation(clazz, list);
            }
            break;
        case CUSTOM:
            {
                Set<Class<? extends DynamicCustomOp>> clazzes = f.getSubTypesOf(DynamicCustomOp.class);
                for (Class<? extends DifferentialFunction> clazz : clazzes) {
                    if (clazz.getSimpleName().equalsIgnoreCase("dynamiccustomop"))
                        continue;
                    addOperation(clazz, list);
                }
            }
            break;
    }
    log.info("Group: {}; List size: {}", type, list.size());
    return list;
}
Also used: lombok.val (lombok.val), ConfigurationBuilder (org.reflections.util.ConfigurationBuilder), Set (java.util.Set), ArrayList (java.util.ArrayList), BaseRandomOp (org.nd4j.linalg.api.ops.random.BaseRandomOp), Variance (org.nd4j.linalg.api.ops.impl.accum.Variance), FilterBuilder (org.reflections.util.FilterBuilder), SubTypesScanner (org.reflections.scanners.SubTypesScanner), DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction), Reflections (org.reflections.Reflections)
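
Each branch of getOperations asks the scanner for all subtypes of one base class and registers them; the CUSTOM branch additionally skips the base class itself. The sketch below illustrates that filtering idea using only the JDK. It is not how the Reflections library works internally: instead of scanning the classpath, it filters an explicit candidate list with Class.isAssignableFrom. The names SubtypeFilterSketch, BaseOp, BaseRandom, Uniform, Gaussian, Add and subTypesOf are all hypothetical stand-ins and are not nd4j or Reflections classes.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class SubtypeFilterSketch {
    // Hypothetical stand-ins for an op hierarchy; these are NOT nd4j classes.
    static abstract class BaseOp {}
    static abstract class BaseRandom extends BaseOp {}
    static class Uniform extends BaseRandom {}
    static class Gaussian extends BaseRandom {}
    static class Add extends BaseOp {}

    /**
     * Returns every candidate that is a subtype of base, excluding base itself.
     * Assumption: candidates are supplied by the caller; a real scanner would
     * discover them on the classpath instead.
     */
    static <T> List<Class<? extends T>> subTypesOf(Class<T> base, List<Class<?>> candidates) {
        List<Class<? extends T>> result = new ArrayList<>();
        for (Class<?> c : candidates) {
            // isAssignableFrom is true for base itself, so skip it explicitly.
            if (c != base && base.isAssignableFrom(c)) {
                result.add(c.asSubclass(base));
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<Class<?>> candidates = Arrays.asList(
                BaseRandom.class, Uniform.class, Gaussian.class, Add.class);

        // Analogous to the RANDOM branch: keep only subtypes of the random base.
        List<Class<? extends BaseRandom>> randomOps = subTypesOf(BaseRandom.class, candidates);
        for (Class<? extends BaseRandom> clazz : randomOps) {
            System.out.println(clazz.getSimpleName());
        }
    }
}
```

With the candidate list above, only Uniform and Gaussian pass the filter: BaseRandom is dropped because it is the base itself, and Add is dropped because it does not extend BaseRandom.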
