Usage of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j (by deeplearning4j).
In class KerasModelConfigurationTest, method importKerasMlpModelMultilossConfigTest:
/**
 * Imports the Keras functional-API MLP configuration with multiple losses
 * ({@code mlp_fapi_multiloss_config.json}) with training-config enforcement enabled,
 * then checks that a {@link ComputationGraph} built from it can be initialised.
 */
@Test
public void importKerasMlpModelMultilossConfigTest() throws Exception {
    ClassPathResource configResource = new ClassPathResource("modelimport/keras/configs/mlp_fapi_multiloss_config.json", KerasModelConfigurationTest.class.getClassLoader());
    // try-with-resources guarantees the classpath stream is closed even if the import fails.
    try (java.io.InputStream configStream = configResource.getInputStream()) {
        ComputationGraphConfiguration config = new KerasModel.ModelBuilder()
                .modelJsonInputStream(configStream)
                .enforceTrainingConfig(true)
                .buildModel()
                .getComputationGraphConfiguration();
        ComputationGraph model = new ComputationGraph(config);
        model.init();
    }
}
Usage of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j (by deeplearning4j).
In class KerasModelConfigurationTest, method importKerasMlpModelConfigTest:
/**
 * Imports the Keras functional-API MLP configuration ({@code mlp_fapi_config.json}) with
 * training-config enforcement enabled, then checks that a {@link ComputationGraph} built
 * from it can be initialised.
 */
@Test
public void importKerasMlpModelConfigTest() throws Exception {
    ClassPathResource configResource = new ClassPathResource("modelimport/keras/configs/mlp_fapi_config.json", KerasModelConfigurationTest.class.getClassLoader());
    // try-with-resources guarantees the classpath stream is closed even if the import fails.
    try (java.io.InputStream configStream = configResource.getInputStream()) {
        ComputationGraphConfiguration config = new KerasModel.ModelBuilder()
                .modelJsonInputStream(configStream)
                .enforceTrainingConfig(true)
                .buildModel()
                .getComputationGraphConfiguration();
        ComputationGraph model = new ComputationGraph(config);
        model.init();
    }
}
Usage of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j (by deeplearning4j).
In class ModelSerializer, method restoreComputationGraph:
/**
 * Load a computation graph from a zip file.
 *
 * <p>The {@code configuration.json} and {@code coefficients.bin} entries are required.
 * When {@code loadUpdater} is true, the {@code UPDATER_BIN} entry (updater state array) is
 * preferred over the legacy {@code OLD_UPDATER_BIN} entry (serialized updater object) if both
 * are present.
 *
 * @param file        the file to get the computation graph from
 * @param loadUpdater whether to also restore updater state from the file
 * @return the loaded computation graph
 * @throws IOException           if the file or one of its entries cannot be read
 * @throws IllegalStateException if the configuration or coefficients entry is missing
 */
public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException {
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotOldUpdater = false;
    boolean gotUpdaterState = false;
    String json = "";
    INDArray params = null;
    ComputationGraphUpdater updater = null;
    INDArray updaterState = null;
    DataSetPreProcessor preProcessor = null;

    // try-with-resources: the archive and every entry stream are closed even when a read throws.
    try (ZipFile zipFile = new ZipFile(file)) {
        ZipEntry config = zipFile.getEntry("configuration.json");
        if (config != null) {
            // NOTE(review): decodes with the platform default charset, as before; switching to
            // UTF-8 would be more portable but must match how the writer encodes this entry.
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(config)))) {
                StringBuilder js = new StringBuilder();
                String line;
                while ((line = reader.readLine()) != null) {
                    js.append(line).append("\n");
                }
                json = js.toString();
            }
            gotConfig = true;
        }

        ZipEntry coefficients = zipFile.getEntry("coefficients.bin");
        if (coefficients != null) {
            try (DataInputStream dis = new DataInputStream(new BufferedInputStream(zipFile.getInputStream(coefficients)))) {
                params = Nd4j.read(dis);
            }
            gotCoefficients = true;
        }

        if (loadUpdater) {
            ZipEntry oldUpdaters = zipFile.getEntry(OLD_UPDATER_BIN);
            if (oldUpdaters != null) {
                // NOTE(review): Java native deserialization; only load model files from trusted sources.
                try (ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(oldUpdaters))) {
                    updater = (ComputationGraphUpdater) ois.readObject();
                } catch (ClassNotFoundException e) {
                    throw new RuntimeException(e);
                }
                gotOldUpdater = true;
            }

            ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
            if (updaterStateEntry != null) {
                try (DataInputStream dis = new DataInputStream(new BufferedInputStream(zipFile.getInputStream(updaterStateEntry)))) {
                    updaterState = Nd4j.read(dis);
                }
                gotUpdaterState = true;
            }
        }

        ZipEntry prep = zipFile.getEntry("preprocessor.bin");
        if (prep != null) {
            // NOTE(review): the preprocessor is deserialized (and type-checked by the cast) but is
            // not attached to the returned graph; confirm whether it should be returned or dropped.
            try (ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(prep))) {
                preProcessor = (DataSetPreProcessor) ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
        }
    }

    if (!gotConfig || !gotCoefficients) {
        throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
    }

    ComputationGraphConfiguration confFromJson = ComputationGraphConfiguration.fromJson(json);
    ComputationGraph cg = new ComputationGraph(confFromJson);
    cg.init(params, false);
    // Prefer the updater state array; fall back to the legacy serialized updater object.
    if (gotUpdaterState && updaterState != null) {
        cg.getUpdater().setStateViewArray(updaterState);
    } else if (gotOldUpdater && updater != null) {
        cg.setUpdater(updater);
    }
    return cg;
}
Usage of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j (by deeplearning4j).
In class BaseOptimizer, method incrementIterationCount:
/**
 * Adds {@code incrementBy} to the iteration counter held by the model's configuration.
 *
 * <p>A {@link MultiLayerNetwork} keeps the counter on its layer-wise configuration and a
 * {@link ComputationGraph} on its graph configuration. Any other model uses {@code model.conf()}.
 *
 * @param model       the model whose iteration count is advanced
 * @param incrementBy amount to add to the current iteration count
 */
public static void incrementIterationCount(Model model, int incrementBy) {
    if (model instanceof MultiLayerNetwork) {
        MultiLayerConfiguration layerWiseConf = ((MultiLayerNetwork) model).getLayerWiseConfigurations();
        int updated = layerWiseConf.getIterationCount() + incrementBy;
        layerWiseConf.setIterationCount(updated);
        return;
    }
    if (model instanceof ComputationGraph) {
        ComputationGraphConfiguration graphConf = ((ComputationGraph) model).getConfiguration();
        int updated = graphConf.getIterationCount() + incrementBy;
        graphConf.setIterationCount(updated);
        return;
    }
    // Fallback for any other Model implementation.
    int updated = model.conf().getIterationCount() + incrementBy;
    model.conf().setIterationCount(updated);
}
Usage of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in project deeplearning4j (by deeplearning4j).
In class ParameterAveragingTrainingMaster, method processResults:
/**
 * Aggregates the per-worker results of one training split and applies them to the driver-side network.
 *
 * <p>Summed parameters and updater state are divided by the aggregation count and set on either
 * {@code network} (when non-null) or {@code graph}. The score is also set. Worker stats and
 * listener data are then forwarded, and finally the configuration's iteration count is advanced.
 * If the aggregated parameters are null, the split is skipped and only logged.
 *
 * @param network     the multi-layer network being trained, or null when training {@code graph}
 * @param graph       the computation graph being trained (used when {@code network} is null)
 * @param results     per-worker training results to aggregate
 * @param splitNum    index of the split just processed (for logging)
 * @param totalSplits total number of splits (for logging)
 */
private void processResults(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<ParameterAveragingTrainingResult> results, int splitNum, int totalSplits) {
    // --- Aggregate worker results ---
    if (collectTrainingStats) {
        stats.logAggregateStartTime();
    }
    ParameterAveragingAggregationTuple aggregated = results.aggregate(null, new ParameterAveragingElementAddFunction(), new ParameterAveragingElementCombineFunction());
    INDArray paramSum = aggregated.getParametersSum();
    int workerCount = aggregated.getAggregationsCount();
    SparkTrainingStats workerStats = aggregated.getSparkTrainingStats();
    if (collectTrainingStats) {
        stats.logAggregationEndTime();
        stats.logProcessParamsUpdaterStart();
    }

    // --- Average and apply parameters / updater state ---
    if (paramSum == null) {
        log.info("Skipping imbalanced split with no data for all executors");
    } else {
        paramSum.divi(workerCount);
        // May be null if all SGD updaters, for example
        INDArray updaterSum = aggregated.getUpdaterStateSum();
        if (updaterSum != null) {
            updaterSum.divi(workerCount);
        }
        if (network == null) {
            ComputationGraph cgNet = graph.getNetwork();
            cgNet.setParams(paramSum);
            if (updaterSum != null) {
                cgNet.getUpdater().setStateViewArray(updaterSum);
            }
            graph.setScore(aggregated.getScoreSum() / aggregated.getAggregationsCount());
        } else {
            MultiLayerNetwork mlNet = network.getNetwork();
            mlNet.setParameters(paramSum);
            if (updaterSum != null) {
                mlNet.getUpdater().setStateViewArray(null, updaterSum, false);
            }
            network.setScore(aggregated.getScoreSum() / aggregated.getAggregationsCount());
        }
    }

    if (collectTrainingStats) {
        stats.logProcessParamsUpdaterEnd();
        stats.addWorkerStats(workerStats);
    }

    // --- Forward listener data to the stats storage, skipping empty collections ---
    if (statsStorage != null) {
        Collection<StorageMetaData> metaData = aggregated.getListenerMetaData();
        if (metaData != null && !metaData.isEmpty()) {
            statsStorage.putStorageMetaData(metaData);
        }
        Collection<Persistable> staticInfos = aggregated.getListenerStaticInfo();
        if (staticInfos != null && !staticInfos.isEmpty()) {
            statsStorage.putStaticInfo(staticInfos);
        }
        Collection<Persistable> listenerUpdates = aggregated.getListenerUpdates();
        if (listenerUpdates != null && !listenerUpdates.isEmpty()) {
            statsStorage.putUpdate(listenerUpdates);
        }
    }

    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
    log.info("Completed training of split {} of {}", splitNum, totalSplits);

    // --- Advance the iteration count (params may be null for an empty RDD) ---
    if (paramSum == null) {
        return;
    }
    if (network == null) {
        ComputationGraphConfiguration graphConf = graph.getNetwork().getConfiguration();
        int updateCount = graph.getNetwork().conf().getNumIterations() * averagingFrequency;
        graphConf.setIterationCount(graphConf.getIterationCount() + updateCount);
    } else {
        MultiLayerConfiguration layerWiseConf = network.getNetwork().getLayerWiseConfigurations();
        int updateCount = network.getNetwork().conf().getNumIterations() * averagingFrequency;
        layerWiseConf.setIterationCount(layerWiseConf.getIterationCount() + updateCount);
    }
}
Aggregations