Use of org.nd4j.linalg.dataset.api.DataSetPreProcessor in project deeplearning4j by deeplearning4j.
The class ModelSerializer, method restoreComputationGraph.
/**
 * Load a computation graph from a zip archive on disk.
 * <p>
 * The archive must contain a {@code configuration.json} entry (the graph configuration as JSON)
 * and a {@code coefficients.bin} entry (the parameter array). Updater entries are optional and
 * are only consulted when {@code loadUpdater} is true.
 *
 * @param file        the file to get the computation graph from
 * @param loadUpdater if true, also restore updater state from the archive when an updater entry
 *                    is present; if false, updater entries are ignored
 * @return the loaded computation graph
 *
 * @throws IOException           if the archive cannot be opened or one of its entries cannot be read
 * @throws IllegalStateException if the archive lacks the configuration or the coefficients entry
 * @throws RuntimeException      wrapping a {@link ClassNotFoundException} raised while deserializing
 *                               the old-format updater or the preprocessor entry
 */
public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException {
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotOldUpdater = false;
    boolean gotUpdaterState = false;
    String json = "";
    INDArray params = null;
    ComputationGraphUpdater updater = null;
    INDArray updaterState = null;

    // try-with-resources guarantees the archive (and with it every entry stream) is released
    // even when reading or deserializing one of the entries fails part-way through.
    try (ZipFile zipFile = new ZipFile(file)) {
        ZipEntry config = zipFile.getEntry("configuration.json");
        if (config != null) {
            //restoring configuration
            // NOTE(review): decodes with the platform default charset, exactly as before;
            // confirm this matches how the archive was written.
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(config)))) {
                StringBuilder js = new StringBuilder();
                String line;
                while ((line = reader.readLine()) != null) {
                    js.append(line).append("\n");
                }
                json = js.toString();
            }
            gotConfig = true;
        }

        ZipEntry coefficients = zipFile.getEntry("coefficients.bin");
        if (coefficients != null) {
            try (DataInputStream dis = new DataInputStream(new BufferedInputStream(zipFile.getInputStream(coefficients)))) {
                params = Nd4j.read(dis);
            }
            gotCoefficients = true;
        }

        if (loadUpdater) {
            // Updater stored as a Java-serialized object under OLD_UPDATER_BIN.
            ZipEntry oldUpdaters = zipFile.getEntry(OLD_UPDATER_BIN);
            if (oldUpdaters != null) {
                // NOTE(review): Java-native deserialization; only load archives from trusted sources.
                try (ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(oldUpdaters))) {
                    updater = (ComputationGraphUpdater) ois.readObject();
                } catch (ClassNotFoundException e) {
                    throw new RuntimeException(e);
                }
                gotOldUpdater = true;
            }

            // Updater state stored as a plain array under UPDATER_BIN.
            ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
            if (updaterStateEntry != null) {
                try (DataInputStream dis = new DataInputStream(new BufferedInputStream(zipFile.getInputStream(updaterStateEntry)))) {
                    updaterState = Nd4j.read(dis);
                }
                gotUpdaterState = true;
            }
        }

        ZipEntry prep = zipFile.getEntry("preprocessor.bin");
        if (prep != null) {
            // The preprocessor is not attached to the returned graph. The entry is still
            // deserialized and type-checked so that an archive with an unreadable preprocessor
            // keeps failing the load exactly as it did before.
            try (ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(prep))) {
                DataSetPreProcessor.class.cast(ois.readObject());
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
        }
    }

    if (gotConfig && gotCoefficients) {
        ComputationGraphConfiguration confFromJson = ComputationGraphConfiguration.fromJson(json);
        ComputationGraph cg = new ComputationGraph(confFromJson);
        cg.init(params, false);
        // The array-based updater state takes precedence over the serialized updater object.
        if (gotUpdaterState && updaterState != null) {
            cg.getUpdater().setStateViewArray(updaterState);
        } else if (gotOldUpdater && updater != null) {
            cg.setUpdater(updater);
        }
        return cg;
    } else {
        throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
    }
}
Use of org.nd4j.linalg.dataset.api.DataSetPreProcessor in project deeplearning4j by deeplearning4j.
The class ModelSerializer, method restoreMultiLayerNetwork.
/**
 * Load a multi layer network from a zip archive on disk.
 * <p>
 * The archive must contain a {@code configuration.json} entry (the network configuration as JSON)
 * and a {@code coefficients.bin} entry (the parameter array). Updater entries are optional and
 * are only consulted when {@code loadUpdater} is true.
 *
 * @param file        the file to load from
 * @param loadUpdater if true, also restore updater state from the archive when an updater entry
 *                    is present; if false, updater entries are ignored
 * @return the loaded multi layer network
 * @throws IOException           if the archive cannot be opened or one of its entries cannot be read
 * @throws IllegalStateException if the archive lacks the configuration or the coefficients entry
 * @throws RuntimeException      wrapping a {@link ClassNotFoundException} raised while deserializing
 *                               the old-format updater or the preprocessor entry
 */
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file, boolean loadUpdater) throws IOException {
    boolean gotConfig = false;
    boolean gotCoefficients = false;
    boolean gotOldUpdater = false;
    boolean gotUpdaterState = false;
    String json = "";
    INDArray params = null;
    Updater updater = null;
    INDArray updaterState = null;

    // try-with-resources guarantees the archive (and with it every entry stream) is released
    // even when reading or deserializing one of the entries fails part-way through.
    try (ZipFile zipFile = new ZipFile(file)) {
        ZipEntry config = zipFile.getEntry("configuration.json");
        if (config != null) {
            //restoring configuration
            // NOTE(review): decodes with the platform default charset, exactly as before;
            // confirm this matches how the archive was written.
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(zipFile.getInputStream(config)))) {
                StringBuilder js = new StringBuilder();
                String line;
                while ((line = reader.readLine()) != null) {
                    js.append(line).append("\n");
                }
                json = js.toString();
            }
            gotConfig = true;
        }

        ZipEntry coefficients = zipFile.getEntry("coefficients.bin");
        if (coefficients != null) {
            try (DataInputStream dis = new DataInputStream(new BufferedInputStream(zipFile.getInputStream(coefficients)))) {
                params = Nd4j.read(dis);
            }
            gotCoefficients = true;
        }

        if (loadUpdater) {
            //This can be removed a few releases after 0.4.1...
            ZipEntry oldUpdaters = zipFile.getEntry(OLD_UPDATER_BIN);
            if (oldUpdaters != null) {
                // NOTE(review): Java-native deserialization; only load archives from trusted sources.
                try (ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(oldUpdaters))) {
                    updater = (Updater) ois.readObject();
                } catch (ClassNotFoundException e) {
                    throw new RuntimeException(e);
                }
                gotOldUpdater = true;
            }

            // Updater state stored as a plain array under UPDATER_BIN.
            ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
            if (updaterStateEntry != null) {
                try (DataInputStream dis = new DataInputStream(new BufferedInputStream(zipFile.getInputStream(updaterStateEntry)))) {
                    updaterState = Nd4j.read(dis);
                }
                gotUpdaterState = true;
            }
        }

        ZipEntry prep = zipFile.getEntry("preprocessor.bin");
        if (prep != null) {
            // The preprocessor is not attached to the returned network. The entry is still
            // deserialized and type-checked so that an archive with an unreadable preprocessor
            // keeps failing the load exactly as it did before.
            try (ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(prep))) {
                DataSetPreProcessor.class.cast(ois.readObject());
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
        }
    }

    if (gotConfig && gotCoefficients) {
        MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
        MultiLayerNetwork network = new MultiLayerNetwork(confFromJson);
        network.init(params, false);
        // The array-based updater state takes precedence over the serialized updater object.
        if (gotUpdaterState && updaterState != null) {
            network.getUpdater().setStateViewArray(network, updaterState, false);
        } else if (gotOldUpdater && updater != null) {
            network.setUpdater(updater);
        }
        return network;
    } else {
        throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
    }
}
Aggregations