Usage of ml.shifu.shifu.container.ModelInitInputObject in project shifu by ShifuML.
Class NNTrainer, method loadWeightsInput.
/**
 * Loads the initial weights for this trainer from the local pre-initialization file
 * {@code ./init<trainerID>.json}, creating or refreshing that file when needed.
 *
 * <p>If the file does not exist, weights are produced by {@code randomSetWeights}, applied via
 * {@code setWeights}, and written to the file. If it exists but records a weight count other
 * than {@code numWeights}, new weights are produced the same way and written back; other fields
 * read from the file are kept. In both existing-file cases the resulting weights are applied via
 * {@code setWeights}.
 *
 * <p>The method declares no checked exceptions: any {@link IOException} is printed to stderr and
 * otherwise ignored.
 *
 * @param numWeights expected number of weights
 */
private void loadWeightsInput(int numWeights) {
    String initPath = "./init" + this.trainerID + ".json";
    File file = new File(initPath);
    try {
        if (!file.exists()) {
            ModelInitInputObject io = new ModelInitInputObject();
            io.setWeights(randomSetWeights(numWeights));
            io.setNumWeights(numWeights);
            setWeights(io.getWeights());
            JSONUtils.writeValue(file, io);
        } else {
            ModelInitInputObject io;
            // Scope the reader to the read only: it is closed even if parsing fails, and it is no
            // longer open when the same file may be rewritten below.
            try (BufferedReader reader = ShifuFileUtils.getReader(initPath, SourceType.LOCAL)) {
                io = JSONUtils.readValue(reader, ModelInitInputObject.class);
            }
            if (io == null) {
                // readValue yielded no object; start from a blank one.
                io = new ModelInitInputObject();
            }
            if (io.getNumWeights() != numWeights) {
                // Stored weights do not match the current network size; regenerate and persist.
                io.setNumWeights(numWeights);
                io.setWeights(randomSetWeights(numWeights));
                JSONUtils.writeValue(file, io);
            }
            setWeights(io.getWeights());
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
}
Usage of ml.shifu.shifu.container.ModelInitInputObject in project shifu by ShifuML.
Class AbstractTrainer, method loadSampleInput.
/**
 * Loads the sample indices from the local pre-initialization file {@code ./init<trainerID>.json},
 * creating or refreshing that file when needed.
 *
 * <p>If the file is missing, or records a sample count different from {@code sampleSize}, new
 * indices are drawn with {@code randomSetSampleIndex} and written to the file. When it already
 * exists, other fields read from it are kept. Otherwise the indices stored in the file are
 * returned as-is.
 *
 * @param sampleSize number of sample indices wanted
 * @param masterSize passed through to {@code randomSetSampleIndex} (presumably the size of the
 *     population indices are drawn from — confirm)
 * @param replaceable passed through to {@code randomSetSampleIndex} (presumably whether sampling
 *     is with replacement — confirm)
 * @return the newly drawn indices, or the index list stored in the file when its sample count
 *     matches {@code sampleSize}
 * @throws IOException if reading or writing the pre-initialization file fails
 */
private List<Integer> loadSampleInput(int sampleSize, int masterSize, boolean replaceable) throws IOException {
    String initPath = "./init" + trainerID + ".json";
    File file = new File(initPath);

    if (!file.exists()) {
        List<Integer> list = randomSetSampleIndex(sampleSize, masterSize, replaceable);
        ModelInitInputObject io = new ModelInitInputObject();
        io.setNumSample(sampleSize);
        io.setSampleIndex(list);
        JSONUtils.writeValue(file, io);
        return list;
    }

    ModelInitInputObject io;
    // Scope the reader to the read only: it is closed even if parsing fails, and it is no longer
    // open when the same file may be rewritten below.
    try (BufferedReader reader = ShifuFileUtils.getReader(initPath, SourceType.LOCAL)) {
        io = JSONUtils.readValue(reader, ModelInitInputObject.class);
    }
    if (io == null) {
        // readValue yielded no object; start from a blank one.
        io = new ModelInitInputObject();
    }
    if (io.getNumSample() == sampleSize) {
        return io.getSampleIndex();
    }

    // Stored sample does not match the requested size; redraw and persist.
    List<Integer> list = randomSetSampleIndex(sampleSize, masterSize, replaceable);
    io.setNumSample(sampleSize);
    io.setSampleIndex(list);
    JSONUtils.writeValue(file, io);
    return list;
}
Aggregations