Use of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j by deeplearning4j.
The class KerasModel, method getComputationGraphConfiguration.
/**
 * Configure a ComputationGraph from this Keras Model configuration.
 *
 * @return ComputationGraphConfiguration
 */
public ComputationGraphConfiguration getComputationGraphConfiguration()
        throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
    if (!this.className.equals(MODEL_CLASS_NAME_MODEL) && !this.className.equals(MODEL_CLASS_NAME_SEQUENTIAL))
        throw new InvalidKerasConfigurationException(
                "Keras model class name " + this.className + " incompatible with ComputationGraph");
    NeuralNetConfiguration.Builder modelBuilder = new NeuralNetConfiguration.Builder();
    ComputationGraphConfiguration.GraphBuilder graphBuilder = modelBuilder.graphBuilder();

    /* Build String array of input layer names, add to ComputationGraph. */
    String[] inputLayerNameArray = new String[this.inputLayerNames.size()];
    this.inputLayerNames.toArray(inputLayerNameArray);
    graphBuilder.addInputs(inputLayerNameArray);

    /* Build InputType array of input layer types, add to ComputationGraph. */
    List<InputType> inputTypeList = new ArrayList<InputType>();
    for (String inputLayerName : this.inputLayerNames)
        inputTypeList.add(this.layers.get(inputLayerName).getOutputType());
    InputType[] inputTypes = new InputType[inputTypeList.size()];
    inputTypeList.toArray(inputTypes);
    graphBuilder.setInputTypes(inputTypes);

    /* Build String array of output layer names, add to ComputationGraph. */
    String[] outputLayerNameArray = new String[this.outputLayerNames.size()];
    this.outputLayerNames.toArray(outputLayerNameArray);
    graphBuilder.setOutputs(outputLayerNameArray);

    Map<String, InputPreProcessor> preprocessors = new HashMap<String, InputPreProcessor>();

    /* Add layersOrdered one at a time. */
    for (KerasLayer layer : this.layersOrdered) {
        /* Get inbound layer names. */
        List<String> inboundLayerNames = layer.getInboundLayerNames();
        String[] inboundLayerNamesArray = new String[inboundLayerNames.size()];
        inboundLayerNames.toArray(inboundLayerNamesArray);

        /* Get inbound InputTypes and InputPreProcessor, if necessary. */
        List<InputType> inboundTypeList = new ArrayList<InputType>();
        for (String layerName : inboundLayerNames)
            inboundTypeList.add(this.outputTypes.get(layerName));
        InputType[] inboundTypeArray = new InputType[inboundTypeList.size()];
        inboundTypeList.toArray(inboundTypeArray);
        InputPreProcessor preprocessor = layer.getInputPreprocessor(inboundTypeArray);

        if (layer.usesRegularization())
            modelBuilder.setUseRegularization(true);

        if (layer.isLayer()) {
            /* Add DL4J layer. */
            if (preprocessor != null)
                preprocessors.put(layer.getLayerName(), preprocessor);
            graphBuilder.addLayer(layer.getLayerName(), layer.getLayer(), inboundLayerNamesArray);
            if (this.outputLayerNames.contains(layer.getLayerName())
                    && !(layer.getLayer() instanceof IOutputLayer))
                log.warn("Model cannot be trained: output layer " + layer.getLayerName()
                        + " is not an IOutputLayer (no loss function specified)");
        } else if (layer.isVertex()) {
            /* Add DL4J vertex. */
            if (preprocessor != null)
                preprocessors.put(layer.getLayerName(), preprocessor);
            graphBuilder.addVertex(layer.getLayerName(), layer.getVertex(), inboundLayerNamesArray);
            if (this.outputLayerNames.contains(layer.getLayerName())
                    && !(layer.getVertex() instanceof IOutputLayer))
                log.warn("Model cannot be trained: output vertex " + layer.getLayerName()
                        + " is not an IOutputLayer (no loss function specified)");
        } else if (layer.isInputPreProcessor()) {
            if (preprocessor == null)
                throw new UnsupportedKerasConfigurationException("Layer " + layer.getLayerName()
                        + " could not be mapped to Layer, Vertex, or InputPreProcessor");
            graphBuilder.addVertex(layer.getLayerName(), new PreprocessorVertex(preprocessor),
                    inboundLayerNamesArray);
            if (this.outputLayerNames.contains(layer.getLayerName()))
                log.warn("Model cannot be trained: output " + layer.getLayerName()
                        + " is not an IOutputLayer (no loss function specified)");
        }
    }
    graphBuilder.setInputPreProcessors(preprocessors);

    /* Whether to use standard backprop (or BPTT) or truncated BPTT. */
    if (this.useTruncatedBPTT && this.truncatedBPTT > 0)
        graphBuilder.backpropType(BackpropType.TruncatedBPTT)
                .tBPTTForwardLength(truncatedBPTT).tBPTTBackwardLength(truncatedBPTT);
    else
        graphBuilder.backpropType(BackpropType.Standard);
    return graphBuilder.build();
}
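For context, a minimal sketch of how the configuration produced by this method might be consumed. The kerasModel variable is assumed here (its construction is not part of the snippet above); ComputationGraph and init() are standard DL4J API:

// Assumes an already-imported KerasModel instance named kerasModel (hypothetical).
ComputationGraphConfiguration conf = kerasModel.getComputationGraphConfiguration();
// Build and initialize the runtime graph from the configuration.
ComputationGraph graph = new ComputationGraph(conf);
graph.init();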
Use of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j by deeplearning4j.
The class ModelGuesserTest, method testModelGuessConfig.
@Test
public void testModelGuessConfig() throws Exception {
    ClassPathResource resource = new ClassPathResource("modelimport/keras/configs/cnn_tf_config.json",
            ModelGuesserTest.class.getClassLoader());
    File f = getTempFile(resource);
    String configFilename = f.getAbsolutePath();
    Object conf = ModelGuesser.loadConfigGuess(configFilename);
    assertTrue(conf instanceof MultiLayerConfiguration);

    ClassPathResource sequenceResource = new ClassPathResource("/keras/simple/mlp_fapi_multiloss_config.json");
    File f2 = getTempFile(sequenceResource);
    Object sequenceConf = ModelGuesser.loadConfigGuess(f2.getAbsolutePath());
    assertTrue(sequenceConf instanceof ComputationGraphConfiguration);

    ClassPathResource resourceDl4j = new ClassPathResource("model.json");
    File fDl4j = getTempFile(resourceDl4j);
    String configFilenameDl4j = fDl4j.getAbsolutePath();
    Object confDl4j = ModelGuesser.loadConfigGuess(configFilenameDl4j);
    assertTrue(confDl4j instanceof ComputationGraphConfiguration);
}
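Because loadConfigGuess returns a plain Object, callers must dispatch on the runtime type themselves, exactly as the assertions above do. A minimal sketch of that pattern (the path variable is assumed):

Object conf = ModelGuesser.loadConfigGuess(path); // path to some JSON config file (assumed)
if (conf instanceof MultiLayerConfiguration) {
    // Sequential-style network configuration
    MultiLayerConfiguration mlc = (MultiLayerConfiguration) conf;
} else if (conf instanceof ComputationGraphConfiguration) {
    // Arbitrary-graph network configuration
    ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) conf;
}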
Use of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j by deeplearning4j.
The class ModelSerializerTest, method testWriteCGModelInputStream.
@Test
public void testWriteCGModelInputStream() throws Exception {
    ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(0.1)
            .graphBuilder()
            .addInputs("in")
            .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in")
            .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).build(), "dense")
            .setOutputs("out")
            .pretrain(false).backprop(true)
            .build();
    ComputationGraph cg = new ComputationGraph(config);
    cg.init();

    File tempFile = File.createTempFile("tsfs", "fdfsdf");
    tempFile.deleteOnExit();
    ModelSerializer.writeModel(cg, tempFile, true);

    FileInputStream fis = new FileInputStream(tempFile);
    ComputationGraph network = ModelSerializer.restoreComputationGraph(fis);
    assertEquals(network.getConfiguration().toJson(), cg.getConfiguration().toJson());
    assertEquals(cg.params(), network.params());
    assertEquals(cg.getUpdater(), network.getUpdater());
}
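When only the configuration (not the trained weights or updater state) needs to be persisted, the config's own JSON round-trip, already used in the assertion above, suffices. A minimal sketch:

// Serialize just the configuration, then restore it from JSON.
String json = config.toJson();
ComputationGraphConfiguration restored = ComputationGraphConfiguration.fromJson(json);
assertEquals(config.toJson(), restored.toJson()); // round-trip preserves the config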
Use of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j by deeplearning4j.
The class TrainModule, method getConfig.
private Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> getConfig() {
    boolean noData = currentSessionID == null;
    StatsStorage ss = (noData ? null : knownSessionIDs.get(currentSessionID));
    List<Persistable> allStatic = (noData ? Collections.EMPTY_LIST
            : ss.getAllStaticInfos(currentSessionID, StatsListener.TYPE_ID));
    if (allStatic.size() == 0)
        return null;

    StatsInitializationReport p = (StatsInitializationReport) allStatic.get(0);
    String modelClass = p.getModelClassName();
    String config = p.getModelConfigJson();
    if (modelClass.endsWith("MultiLayerNetwork")) {
        MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(config);
        return new Triple<>(conf, null, null);
    } else if (modelClass.endsWith("ComputationGraph")) {
        ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(config);
        return new Triple<>(null, conf, null);
    } else {
        try {
            NeuralNetConfiguration layer =
                    NeuralNetConfiguration.mapper().readValue(config, NeuralNetConfiguration.class);
            return new Triple<>(null, null, layer);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
    return null;
}
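At most one slot of the returned Triple is non-null, so a caller checks each in turn. A minimal sketch, assuming the Triple type used here exposes getFirst/getSecond/getThird accessors:

Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> t = getConfig();
if (t != null) {
    if (t.getFirst() != null) {
        // MultiLayerNetwork configuration
    } else if (t.getSecond() != null) {
        ComputationGraphConfiguration cgc = t.getSecond(); // ComputationGraph configuration
    } else if (t.getThird() != null) {
        // Single NeuralNetConfiguration (e.g. a pretrain-style layer)
    }
}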
Use of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j by deeplearning4j.
The class TrainModule, method getLayerInfoTable.
private String[][] getLayerInfoTable(int layerIdx, TrainModuleUtils.GraphInfo gi, I18N i18N, boolean noData,
        StatsStorage ss, String wid) {
    List<String[]> layerInfoRows = new ArrayList<>();
    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerName"),
            gi.getVertexNames().get(layerIdx)});
    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerType"), ""});

    if (!noData) {
        Persistable p = ss.getStaticInfo(currentSessionID, StatsListener.TYPE_ID, wid);
        if (p != null) {
            StatsInitializationReport initReport = (StatsInitializationReport) p;
            String configJson = initReport.getModelConfigJson();
            String modelClass = initReport.getModelClassName();

            //TODO error handling...
            String layerType = "";
            Layer layer = null;
            NeuralNetConfiguration nnc = null;
            if (modelClass.endsWith("MultiLayerNetwork")) {
                MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(configJson);
                int confIdx = layerIdx - 1; //-1 because of input
                if (confIdx >= 0) {
                    nnc = conf.getConf(confIdx);
                    layer = nnc.getLayer();
                } else {
                    //Input layer
                    layerType = "Input";
                }
            } else if (modelClass.endsWith("ComputationGraph")) {
                ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(configJson);
                String vertexName = gi.getVertexNames().get(layerIdx);
                Map<String, GraphVertex> vertices = conf.getVertices();
                if (vertices.containsKey(vertexName) && vertices.get(vertexName) instanceof LayerVertex) {
                    LayerVertex lv = (LayerVertex) vertices.get(vertexName);
                    nnc = lv.getLayerConf();
                    layer = nnc.getLayer();
                } else if (conf.getNetworkInputs().contains(vertexName)) {
                    layerType = "Input";
                } else {
                    GraphVertex gv = conf.getVertices().get(vertexName);
                    if (gv != null) {
                        layerType = gv.getClass().getSimpleName();
                    }
                }
            } else if (modelClass.endsWith("VariationalAutoencoder")) {
                layerType = gi.getVertexTypes().get(layerIdx);
                Map<String, String> map = gi.getVertexInfo().get(layerIdx);
                for (Map.Entry<String, String> entry : map.entrySet()) {
                    layerInfoRows.add(new String[] {entry.getKey(), entry.getValue()});
                }
            }

            if (layer != null) {
                layerType = getLayerType(layer);

                String activationFn = null;
                if (layer instanceof FeedForwardLayer) {
                    FeedForwardLayer ffl = (FeedForwardLayer) layer;
                    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerNIn"),
                            String.valueOf(ffl.getNIn())});
                    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerSize"),
                            String.valueOf(ffl.getNOut())});
                    activationFn = layer.getActivationFn().toString();
                }
                int nParams = layer.initializer().numParams(nnc);
                layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerNParams"),
                        String.valueOf(nParams)});
                if (nParams > 0) {
                    WeightInit wi = layer.getWeightInit();
                    String str = wi.toString();
                    if (wi == WeightInit.DISTRIBUTION) {
                        str += layer.getDist();
                    }
                    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerWeightInit"), str});
                    Updater u = layer.getUpdater();
                    String us = (u == null ? "" : u.toString());
                    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerUpdater"), us});
                    //TODO: Maybe L1/L2, dropout, updater-specific values etc
                }
                if (layer instanceof ConvolutionLayer || layer instanceof SubsamplingLayer) {
                    int[] kernel;
                    int[] stride;
                    int[] padding;
                    if (layer instanceof ConvolutionLayer) {
                        ConvolutionLayer cl = (ConvolutionLayer) layer;
                        kernel = cl.getKernelSize();
                        stride = cl.getStride();
                        padding = cl.getPadding();
                    } else {
                        SubsamplingLayer ssl = (SubsamplingLayer) layer;
                        kernel = ssl.getKernelSize();
                        stride = ssl.getStride();
                        padding = ssl.getPadding();
                        activationFn = null;
                        layerInfoRows.add(new String[] {
                                i18N.getMessage("train.model.layerinfotable.layerSubsamplingPoolingType"),
                                ssl.getPoolingType().toString()});
                    }
                    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerCnnKernel"),
                            Arrays.toString(kernel)});
                    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerCnnStride"),
                            Arrays.toString(stride)});
                    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerCnnPadding"),
                            Arrays.toString(padding)});
                }
                if (activationFn != null) {
                    layerInfoRows.add(new String[] {i18N.getMessage("train.model.layerinfotable.layerActivationFn"),
                            activationFn});
                }
            }
            layerInfoRows.get(1)[1] = layerType;
        }
    }

    return layerInfoRows.toArray(new String[layerInfoRows.size()][0]);
}
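The ComputationGraph branch above works off the configuration's vertex map; the same lookup can be done standalone. A minimal sketch (the json variable, holding a serialized configuration, is assumed):

ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(json);
for (Map.Entry<String, GraphVertex> e : conf.getVertices().entrySet()) {
    if (e.getValue() instanceof LayerVertex) {
        // Per-vertex layer configuration, as used by getLayerInfoTable above.
        Layer l = ((LayerVertex) e.getValue()).getLayerConf().getLayer();
        System.out.println(e.getKey() + " -> " + l.getClass().getSimpleName());
    }
}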