Search in sources :

Example 1 with DL4JSubTypesScanner

Use of org.deeplearning4j.util.reflections.DL4JSubTypesScanner in project deeplearning4j by deeplearning4j.

The method registerSubtypes of the class NeuralNetConfiguration.

/**
 * Registers, on the given mapper, every concrete subtype of the serializable base types
 * (InputPreProcessor, ILossFunction, IActivation, Layer, GraphVertex, ReconstructionDistribution)
 * that the mapper does not already know about.
 * <p>
 * The set of candidate subtypes is found by a classpath scan that runs only while
 * {@code subtypesClassCache} is {@code null}; later calls reuse the cached set. If the system
 * property named by {@code CUSTOM_FUNCTIONALITY} is set to a value that does not parse as
 * {@code true}, the scan is skipped and the cache becomes an empty set.
 * <p>
 * Each subtype is registered under its simple name. Member classes are registered as
 * {@code Outer$Inner}. Anonymous and local classes (which have no declaring class) are registered
 * under their binary name without the package prefix, e.g. {@code Outer$1}.
 *
 * @param mapper the mapper to register the discovered subtypes on
 */
private static synchronized void registerSubtypes(ObjectMapper mapper) {
    //Base types whose concrete subtypes are registered for JSON serialization
    List<Class<?>> classes = Arrays.<Class<?>>asList(InputPreProcessor.class, ILossFunction.class, IActivation.class, Layer.class, GraphVertex.class, ReconstructionDistribution.class);
    // First: scan the classpath and find all instances of the 'baseClasses' classes
    if (subtypesClassCache == null) {
        //Check system property:
        String prop = System.getProperty(CUSTOM_FUNCTIONALITY);
        if (prop != null && !Boolean.parseBoolean(prop)) {
            //Property explicitly set to something other than "true": skip the scan entirely
            subtypesClassCache = Collections.emptySet();
        } else {
            List<Class<?>> interfaces = Arrays.<Class<?>>asList(InputPreProcessor.class, ILossFunction.class, IActivation.class, ReconstructionDistribution.class);
            List<Class<?>> classesList = Arrays.<Class<?>>asList(Layer.class, GraphVertex.class);
            Collection<URL> urls = ClasspathHelper.forClassLoader();
            List<URL> scanUrls = new ArrayList<>();
            for (URL u : urls) {
                String path = u.getPath();
                if (!path.matches(".*/jre/lib/.*jar")) {
                    //Skip JRE/JDK JARs
                    scanUrls.add(u);
                }
            }
            Reflections reflections = new Reflections(new ConfigurationBuilder().filterInputsBy(new FilterBuilder().exclude(//Consider only .class files (to avoid debug messages etc. on .dlls, etc
            "^(?!.*\\.class$).*$").exclude("^org.nd4j.*").exclude("^org.datavec.*").exclude(//JavaCPP
            "^org.bytedeco.*").exclude(//Jackson
            "^com.fasterxml.*").exclude(//Apache commons, Spark, log4j etc
            "^org.apache.*").exclude("^org.projectlombok.*").exclude("^com.twelvemonkeys.*").exclude("^org.joda.*").exclude("^org.slf4j.*").exclude("^com.google.*").exclude("^org.reflections.*").exclude(//Logback
            "^ch.qos.*")).addUrls(scanUrls).setScanners(new DL4JSubTypesScanner(interfaces, classesList)));
            //The base type names are only needed to query the scan results, so build them here
            List<String> classNames = new ArrayList<>(classes.size());
            for (Class<?> c : classes) {
                classNames.add(c.getName());
            }
            org.reflections.Store store = reflections.getStore();
            Iterable<String> subtypesByName = store.getAll(DL4JSubTypesScanner.class.getSimpleName(), classNames);
            Set<? extends Class<?>> subtypeClasses = Sets.newHashSet(ReflectionUtils.forNames(subtypesByName));
            subtypesClassCache = new HashSet<>();
            for (Class<?> c : subtypeClasses) {
                if (Modifier.isAbstract(c.getModifiers()) || Modifier.isInterface(c.getModifiers())) {
                    //Only concrete classes are cached
                    continue;
                }
                subtypesClassCache.add(c);
            }
        }
    }
    //Second: get all currently registered subtypes for this mapper
    Set<Class<?>> registeredSubtypes = new HashSet<>();
    for (Class<?> c : classes) {
        AnnotatedClass ac = AnnotatedClass.construct(c, mapper.getSerializationConfig().getAnnotationIntrospector(), null);
        Collection<NamedType> types = mapper.getSubtypeResolver().collectAndResolveSubtypes(ac, mapper.getSerializationConfig(), mapper.getSerializationConfig().getAnnotationIntrospector());
        for (NamedType nt : types) {
            registeredSubtypes.add(nt.getType());
        }
    }
    //Third: register all _concrete_ subtypes that are not already registered
    List<NamedType> toRegister = new ArrayList<>();
    for (Class<?> c : subtypesClassCache) {
        //Check if it's concrete or abstract...
        if (Modifier.isAbstract(c.getModifiers()) || Modifier.isInterface(c.getModifiers())) {
            continue;
        }
        if (!registeredSubtypes.contains(c)) {
            String name;
            if (ClassUtils.isInnerClass(c)) {
                Class<?> c2 = c.getDeclaringClass();
                if (c2 != null) {
                    name = c2.getSimpleName() + "$" + c.getSimpleName();
                } else {
                    //Anonymous and local classes have no declaring class (getDeclaringClass() returns null).
                    //Fall back to the binary name without the package, which is unique per class,
                    //instead of failing the whole registration with a NullPointerException.
                    String binaryName = c.getName();
                    name = binaryName.substring(binaryName.lastIndexOf('.') + 1);
                }
            } else {
                name = c.getSimpleName();
            }
            toRegister.add(new NamedType(c, name));
            if (log.isDebugEnabled()) {
                //Report only the first matching base type
                for (Class<?> baseClass : classes) {
                    if (baseClass.isAssignableFrom(c)) {
                        log.debug("Registering class for JSON serialization: {} as subtype of {}", c.getName(), baseClass.getName());
                        break;
                    }
                }
            }
        }
    }
    mapper.registerSubtypes(toRegister.toArray(new NamedType[0]));
}
Also used : ConfigurationBuilder(org.reflections.util.ConfigurationBuilder) NamedType(org.nd4j.shade.jackson.databind.jsontype.NamedType) URL(java.net.URL) AnnotatedClass(org.nd4j.shade.jackson.databind.introspect.AnnotatedClass) FilterBuilder(org.reflections.util.FilterBuilder) DL4JSubTypesScanner(org.deeplearning4j.util.reflections.DL4JSubTypesScanner) Reflections(org.reflections.Reflections)

Aggregations

URL (java.net.URL)1 DL4JSubTypesScanner (org.deeplearning4j.util.reflections.DL4JSubTypesScanner)1 AnnotatedClass (org.nd4j.shade.jackson.databind.introspect.AnnotatedClass)1 NamedType (org.nd4j.shade.jackson.databind.jsontype.NamedType)1 Reflections (org.reflections.Reflections)1 ConfigurationBuilder (org.reflections.util.ConfigurationBuilder)1 FilterBuilder (org.reflections.util.FilterBuilder)1