Usage example of org.deeplearning4j.nn.graph.ComputationGraph in the deeplearning4j project — class KerasModelConfigurationTest, method importKerasMlpModelConfigTest:
@Test
public void importKerasMlpModelConfigTest() throws Exception {
    // Locate the Keras functional-API MLP configuration on the test classpath.
    ClassPathResource resource = new ClassPathResource("modelimport/keras/configs/mlp_fapi_config.json",
            KerasModelConfigurationTest.class.getClassLoader());

    // Import the Keras JSON into a DL4J computation-graph configuration,
    // requiring that training-related settings are present in the file.
    ComputationGraphConfiguration importedConfig = new KerasModel.ModelBuilder()
            .modelJsonInputStream(resource.getInputStream())
            .enforceTrainingConfig(true)
            .buildModel()
            .getComputationGraphConfiguration();

    // Constructing and initializing the graph validates the imported configuration.
    ComputationGraph graph = new ComputationGraph(importedConfig);
    graph.init();
}
Usage example of org.deeplearning4j.nn.graph.ComputationGraph in the deeplearning4j project — class ParallelWrapper, method averageUpdatersState:
/**
 * Averages the updater state (e.g. momentum / RMS history) across the workers that
 * participated in the current averaging round, then applies the averaged score to
 * the shared model. Only the first {@code locker.get()} workers contribute.
 *
 * Fix: removed the {@code batchSize} accumulator, which was summed in both branches
 * but never read (dead code).
 *
 * @param locker counter of workers dispatched in the current round
 * @param score  averaged score to set on the shared model
 */
private void averageUpdatersState(AtomicInteger locker, double score) {
    if (averageUpdaters) {
        ComputationGraphUpdater updater = ((ComputationGraph) model).getUpdater();
        if (updater != null && updater.getStateViewArray() != null) {
            if (!legacyAveraging || Nd4j.getAffinityManager().getNumberOfDevices() == 1) {
                // Preferred path: collect each worker's state view and let Nd4j
                // average and propagate the result in a single pass.
                List<INDArray> updaters = new ArrayList<>();
                for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                    ComputationGraph workerModel = (ComputationGraph) zoo[cnt].getModel();
                    updaters.add(workerModel.getUpdater().getStateViewArray());
                }
                Nd4j.averageAndPropagate(updater.getStateViewArray(), updaters);
            } else {
                // Legacy multi-device path: accumulate the state manually, divide by
                // the number of contributing workers, and push it back into the
                // shared model's updater.
                INDArray state = Nd4j.zeros(updater.getStateViewArray().shape());
                int cnt = 0;
                for (; cnt < workers && cnt < locker.get(); cnt++) {
                    ComputationGraph workerModel = (ComputationGraph) zoo[cnt].getModel();
                    state.addi(workerModel.getUpdater().getStateViewArray());
                }
                state.divi(cnt);
                updater.setStateViewArray(state);
            }
        }
    }
    ((ComputationGraph) model).setScore(score);
}
Usage example of org.deeplearning4j.nn.graph.ComputationGraph in the deeplearning4j project — class ParallelWrapper, method fit:
/**
 * Trains the underlying model in data-parallel fashion from the given
 * MultiDataSet source. Mini-batches are dispatched round-robin to the worker
 * threads; once every worker is busy (or the iterator is exhausted) this thread
 * blocks until all dispatched workers finish, then periodically averages
 * parameters/updater state across workers.
 *
 * NOTE(review): only ComputationGraph models are supported on this path — a
 * RuntimeException is thrown for any other model type once averaging is reached.
 *
 * @param source iterator of MultiDataSets to train on; it is reset before use
 */
public synchronized void fit(@NonNull MultiDataSetIterator source) {
    stopFit.set(false);
    // Lazily create the worker threads on first use; on later calls just switch
    // the existing workers over to the MultiDataSet queue.
    if (zoo == null) {
        zoo = new Trainer[workers];
        for (int cnt = 0; cnt < workers; cnt++) {
            // we pass true here, to tell Trainer to use MultiDataSet queue for training
            zoo[cnt] = new Trainer(cnt, model, Nd4j.getAffinityManager().getDeviceForCurrentThread(), true);
            zoo[cnt].setUncaughtExceptionHandler(handler);
            zoo[cnt].start();
        }
    } else {
        for (int cnt = 0; cnt < workers; cnt++) {
            zoo[cnt].useMDS = true;
        }
    }
    source.reset();
    // Optionally wrap the source in an async prefetching iterator.
    MultiDataSetIterator iterator;
    if (prefetchSize > 0 && source.asyncSupported()) {
        iterator = new AsyncMultiDataSetIterator(source, prefetchSize);
    } else
        iterator = source;
    // Counts how many workers have been dispatched in the current round.
    AtomicInteger locker = new AtomicInteger(0);
    while (iterator.hasNext() && !stopFit.get()) {
        MultiDataSet dataSet = iterator.next();
        if (dataSet == null)
            throw new ND4JIllegalStateException("You can't have NULL as MultiDataSet");
        /*
            now dataSet should be dispatched to next free workers, until all workers are busy. And then we should block till all finished.
        */
        int pos = locker.getAndIncrement();
        zoo[pos].feedMultiDataSet(dataSet);
        /*
            if all workers are dispatched now, join till all are finished
        */
        if (pos + 1 == workers || !iterator.hasNext()) {
            iterationsCounter.incrementAndGet();
            // Wait for every dispatched worker to finish its batch.
            for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                try {
                    zoo[cnt].waitTillRunning();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
            /*
                average model, and propagate it to whole
            */
            // Averaging only happens on full rounds (pos + 1 == workers), at the
            // configured frequency.
            if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
                double score = getScore(locker);
                // averaging updaters state
                if (model instanceof ComputationGraph) {
                    averageUpdatersState(locker, score);
                } else
                    throw new RuntimeException("MultiDataSet must only be used with ComputationGraph model");
                // Legacy multi-device mode: explicitly push the averaged model
                // back to every worker.
                if (legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
                    for (int cnt = 0; cnt < workers; cnt++) {
                        zoo[cnt].updateModel(model);
                    }
                }
            }
            // Start the next dispatch round.
            locker.set(0);
        }
    }
    // sanity checks, or the dataset may never average
    // NOTE(review): wasAveraged is never assigned within this method — confirm it
    // is set elsewhere (e.g. in getScore/averaging helpers), otherwise this
    // warning fires on every fit().
    if (!wasAveraged)
        log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
    // throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");
    log.debug("Iterations passed: {}", iterationsCounter.get());
    // iterationsCounter.set(0);
}
Usage example of org.deeplearning4j.nn.graph.ComputationGraph in the deeplearning4j project — class ModelSerializer, method writeModel:
/**
 * Write a model to an output stream as a zip archive (configuration JSON,
 * parameter coefficients, and optionally the updater state). The given stream
 * itself is not closed — it is wrapped in a CloseShieldOutputStream, so the
 * caller retains ownership.
 *
 * Fix: the configuration JSON was previously written with {@code getBytes()}
 * using the platform-default charset, making saved models non-portable between
 * machines with different default encodings; it is now always UTF-8.
 *
 * @param model the model to save (MultiLayerNetwork or ComputationGraph)
 * @param stream the output stream to write to
 * @param saveUpdater whether to save the updater state for the model or not
 * @throws IOException if writing to the stream fails
 */
public static void writeModel(@NonNull Model model, @NonNull OutputStream stream, boolean saveUpdater) throws IOException {
    ZipOutputStream zipfile = new ZipOutputStream(new CloseShieldOutputStream(stream));
    // Save configuration as JSON
    String json = "";
    if (model instanceof MultiLayerNetwork) {
        json = ((MultiLayerNetwork) model).getLayerWiseConfigurations().toJson();
    } else if (model instanceof ComputationGraph) {
        json = ((ComputationGraph) model).getConfiguration().toJson();
    }
    ZipEntry config = new ZipEntry("configuration.json");
    zipfile.putNextEntry(config);
    // Always UTF-8 so archives are portable across platforms.
    zipfile.write(json.getBytes(java.nio.charset.StandardCharsets.UTF_8));
    // Save parameters as binary
    ZipEntry coefficients = new ZipEntry("coefficients.bin");
    zipfile.putNextEntry(coefficients);
    DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(zipfile));
    try {
        Nd4j.write(model.params(), dos);
    } finally {
        dos.flush();
        // Closing dos also closes the underlying ZipOutputStream, so it is only
        // closed here when no further entries will be written.
        if (!saveUpdater)
            dos.close();
    }
    if (saveUpdater) {
        INDArray updaterState = null;
        if (model instanceof MultiLayerNetwork) {
            updaterState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
        } else if (model instanceof ComputationGraph) {
            updaterState = ((ComputationGraph) model).getUpdater().getStateViewArray();
        }
        if (updaterState != null && updaterState.length() > 0) {
            ZipEntry updater = new ZipEntry(UPDATER_BIN);
            zipfile.putNextEntry(updater);
            try {
                Nd4j.write(updaterState, dos);
            } finally {
                dos.flush();
                dos.close();
            }
        }
    }
    zipfile.close();
}
Usage example of org.deeplearning4j.nn.graph.ComputationGraph in the deeplearning4j project — class ModelSerializer, method restoreComputationGraph:
/**
 * Load a computation graph from a zip archive previously written by writeModel().
 *
 * Fixes: the ZipFile was leaked if any entry read threw (no try/finally), and the
 * streams used for the legacy updater and preprocessor entries were never closed;
 * all reads now use try-with-resources and the archive is closed in a finally block.
 *
 * @param file the file to get the computation graph from
 * @param loadUpdater whether to also restore updater state, if present
 * @return the loaded computation graph
 * @throws IOException if the archive cannot be read
 * @throws IllegalStateException if the archive lacks configuration or coefficients
 */
public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException {
    ZipFile zipFile = new ZipFile(file);
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotOldUpdater = false;
    boolean gotUpdaterState = false;
    String json = "";
    INDArray params = null;
    ComputationGraphUpdater updater = null;
    INDArray updaterState = null;
    try {
        // Restore configuration JSON.
        ZipEntry config = zipFile.getEntry("configuration.json");
        if (config != null) {
            StringBuilder js = new StringBuilder();
            try (BufferedReader reader =
                            new BufferedReader(new InputStreamReader(zipFile.getInputStream(config)))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    js.append(line).append("\n");
                }
            }
            json = js.toString();
            gotConfig = true;
        }
        // Restore network parameters.
        ZipEntry coefficients = zipFile.getEntry("coefficients.bin");
        if (coefficients != null) {
            try (DataInputStream dis =
                            new DataInputStream(new BufferedInputStream(zipFile.getInputStream(coefficients)))) {
                params = Nd4j.read(dis);
            }
            gotCoefficients = true;
        }
        if (loadUpdater) {
            // Legacy format: a Java-serialized ComputationGraphUpdater.
            ZipEntry oldUpdaters = zipFile.getEntry(OLD_UPDATER_BIN);
            if (oldUpdaters != null) {
                try (ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(oldUpdaters))) {
                    updater = (ComputationGraphUpdater) ois.readObject();
                } catch (ClassNotFoundException e) {
                    throw new RuntimeException(e);
                }
                gotOldUpdater = true;
            }
            // Current format: the raw updater state view array.
            ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
            if (updaterStateEntry != null) {
                try (DataInputStream dis = new DataInputStream(
                                new BufferedInputStream(zipFile.getInputStream(updaterStateEntry)))) {
                    updaterState = Nd4j.read(dis);
                }
                gotUpdaterState = true;
            }
        }
        // NOTE(review): the preprocessor is deserialized for parity with the
        // original code but its value is never used here — confirm whether it
        // should be attached to anything or the read can be dropped.
        ZipEntry prep = zipFile.getEntry("preprocessor.bin");
        if (prep != null) {
            try (ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(prep))) {
                DataSetPreProcessor preProcessor = (DataSetPreProcessor) ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
        }
    } finally {
        // Previously leaked when any read above threw.
        zipFile.close();
    }
    if (gotConfig && gotCoefficients) {
        ComputationGraphConfiguration confFromJson = ComputationGraphConfiguration.fromJson(json);
        ComputationGraph cg = new ComputationGraph(confFromJson);
        cg.init(params, false);
        // Prefer the new updater-state format; fall back to the legacy serialized updater.
        if (gotUpdaterState && updaterState != null) {
            cg.getUpdater().setStateViewArray(updaterState);
        } else if (gotOldUpdater && updater != null) {
            cg.setUpdater(updater);
        }
        return cg;
    } else
        throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
}
Aggregations