Usage of org.deeplearning4j.util.reflections.DL4JSubTypesScanner in the deeplearning4j project (by deeplearning4j).
Method registerSubtypes of class NeuralNetConfiguration.
/**
 * Registers on {@code mapper} every concrete subtype of the DL4J JSON base types (InputPreProcessor,
 * ILossFunction, IActivation, Layer, GraphVertex, ReconstructionDistribution) that the mapper does
 * not already know about.
 *
 * <p>Works in three phases:
 * <ol>
 *   <li>On the first call only, populate {@code subtypesClassCache}. If the system property named by
 *       {@code CUSTOM_FUNCTIONALITY} is set to anything other than "true" (case-insensitive), the
 *       cache is set to an empty set and no classpath scan happens. Otherwise the classpath is
 *       scanned for concrete subtypes.</li>
 *   <li>Collect the subtypes the mapper's subtype resolver already reports for each base type.</li>
 *   <li>Register each cached concrete class not already known. It is registered under its simple
 *       name, or {@code Declaring$Simple} for a nested member class.</li>
 * </ol>
 *
 * <p>Anonymous and local classes are skipped. {@link Class#getDeclaringClass()} is null for them,
 * which previously caused a NullPointerException while building the type name, and Jackson does
 * not support deserializing such classes anyway.
 *
 * <p>The method is {@code static synchronized}, so the one-time cache initialization happens under
 * the class lock.
 *
 * @param mapper the Jackson mapper to register the subtypes on
 */
private static synchronized void registerSubtypes(ObjectMapper mapper) {
    //Register concrete subtypes for JSON serialization
    List<Class<?>> classes = Arrays.<Class<?>>asList(InputPreProcessor.class, ILossFunction.class,
                    IActivation.class, Layer.class, GraphVertex.class, ReconstructionDistribution.class);
    List<String> classNames = new ArrayList<>(classes.size());
    for (Class<?> c : classes) {
        classNames.add(c.getName());
    }

    //First: (once) find all concrete subtypes of the base classes on the classpath
    if (subtypesClassCache == null) {
        //Classpath scanning can be disabled via system property
        String prop = System.getProperty(CUSTOM_FUNCTIONALITY);
        if (prop != null && !Boolean.parseBoolean(prop)) {
            subtypesClassCache = Collections.emptySet();
        } else {
            subtypesClassCache = findConcreteSubtypesOnClasspath(classNames);
        }
    }

    //Second: get all currently registered subtypes for this mapper
    Set<Class<?>> registeredSubtypes = new HashSet<>();
    for (Class<?> c : classes) {
        AnnotatedClass ac = AnnotatedClass.construct(c,
                        mapper.getSerializationConfig().getAnnotationIntrospector(), null);
        Collection<NamedType> types = mapper.getSubtypeResolver().collectAndResolveSubtypes(ac,
                        mapper.getSerializationConfig(),
                        mapper.getSerializationConfig().getAnnotationIntrospector());
        for (NamedType nt : types) {
            registeredSubtypes.add(nt.getType());
        }
    }

    //Third: register all _concrete_ subtypes that are not already registered
    List<NamedType> toRegister = new ArrayList<>();
    for (Class<?> c : subtypesClassCache) {
        //The cache may not have come from the scan above, so re-check concreteness
        if (Modifier.isAbstract(c.getModifiers()) || Modifier.isInterface(c.getModifiers())) {
            continue;
        }
        //Anonymous/local classes have no declaring class (no stable type name) and cannot be
        //deserialized by Jackson; previously they caused an NPE when computing the name
        if (c.isAnonymousClass() || c.isLocalClass()) {
            log.debug("Skipping anonymous/local class for JSON subtype registration: {}", c.getName());
            continue;
        }
        if (registeredSubtypes.contains(c)) {
            continue;
        }
        toRegister.add(new NamedType(c, jsonSubtypeName(c)));
        if (log.isDebugEnabled()) {
            for (Class<?> baseClass : classes) {
                if (baseClass.isAssignableFrom(c)) {
                    log.debug("Registering class for JSON serialization: {} as subtype of {}", c.getName(),
                                    baseClass.getName());
                    break;
                }
            }
        }
    }
    mapper.registerSubtypes(toRegister.toArray(new NamedType[toRegister.size()]));
}

/**
 * Scans the classpath, excluding JRE JARs and well-known third-party packages, for subtypes of the
 * DL4J JSON base types. Returns only the concrete ones (not abstract, not interfaces).
 *
 * @param classNames fully-qualified names of the base types; used as keys when querying the
 *                   {@code DL4JSubTypesScanner} entries of the Reflections store
 * @return a new, mutable set of the concrete subtype classes found
 */
private static Set<Class<?>> findConcreteSubtypesOnClasspath(List<String> classNames) {
    List<Class<?>> interfaces = Arrays.<Class<?>>asList(InputPreProcessor.class, ILossFunction.class,
                    IActivation.class, ReconstructionDistribution.class);
    List<Class<?>> classesList = Arrays.<Class<?>>asList(Layer.class, GraphVertex.class);

    Collection<URL> urls = ClasspathHelper.forClassLoader();
    List<URL> scanUrls = new ArrayList<>();
    for (URL u : urls) {
        //Skip JRE/JDK JARs
        if (!u.getPath().matches(".*/jre/lib/.*jar")) {
            scanUrls.add(u);
        }
    }

    FilterBuilder inputFilter = new FilterBuilder()
                    //Consider only .class files (to avoid debug messages etc. on .dlls, etc
                    .exclude("^(?!.*\\.class$).*$")
                    .exclude("^org.nd4j.*")
                    .exclude("^org.datavec.*")
                    //JavaCPP
                    .exclude("^org.bytedeco.*")
                    //Jackson
                    .exclude("^com.fasterxml.*")
                    //Apache commons, Spark, log4j etc
                    .exclude("^org.apache.*")
                    .exclude("^org.projectlombok.*")
                    .exclude("^com.twelvemonkeys.*")
                    .exclude("^org.joda.*")
                    .exclude("^org.slf4j.*")
                    .exclude("^com.google.*")
                    .exclude("^org.reflections.*")
                    //Logback
                    .exclude("^ch.qos.*");
    Reflections reflections = new Reflections(new ConfigurationBuilder().filterInputsBy(inputFilter)
                    .addUrls(scanUrls).setScanners(new DL4JSubTypesScanner(interfaces, classesList)));

    org.reflections.Store store = reflections.getStore();
    Iterable<String> subtypesByName = store.getAll(DL4JSubTypesScanner.class.getSimpleName(), classNames);
    Set<? extends Class<?>> subtypeClasses = Sets.newHashSet(ReflectionUtils.forNames(subtypesByName));

    Set<Class<?>> concrete = new HashSet<>();
    for (Class<?> c : subtypeClasses) {
        if (Modifier.isAbstract(c.getModifiers()) || Modifier.isInterface(c.getModifiers())) {
            continue;
        }
        concrete.add(c);
    }
    return concrete;
}

/**
 * Returns the JSON type name for a subtype. For a class that {@code ClassUtils.isInnerClass}
 * reports as nested and that has a declaring class, this is {@code Declaring$Simple} (one nesting
 * level only). Otherwise it is the class's simple name.
 *
 * @param c the subtype class
 * @return the name to register the class under
 */
private static String jsonSubtypeName(Class<?> c) {
    //getDeclaringClass() is null for anonymous/local classes; fall back to the simple name
    Class<?> declaring = c.getDeclaringClass();
    if (ClassUtils.isInnerClass(c) && declaring != null) {
        return declaring.getSimpleName() + "$" + c.getSimpleName();
    }
    return c.getSimpleName();
}
Aggregations