Use of org.encog.neural.flat.FlatNetwork in the ShifuML/shifu project.
Class PersistBasicFloatNetwork, method save().
/**
 * Persists a {@link BasicFloatNetwork} to the given stream in Encog's EG text format.
 *
 * <p>Writes a single BASIC section containing the PARAMS, NETWORK and ACTIVATION
 * sub-sections (mirroring Encog's standard BasicNetwork persistence) plus a
 * shifu-specific SUBSET sub-section holding the selected feature indexes as a
 * comma-separated list. The write order must match what {@code read(InputStream)}
 * expects, so no statements here may be reordered.
 *
 * @param os  destination stream; not closed by this method
 * @param obj the network to persist, must be a {@link BasicFloatNetwork}
 *
 * {@inheritDoc}
 */
@Override
public final void save(final OutputStream os, final Object obj) {
    final EncogWriteHelper out = new EncogWriteHelper(os);
    final BasicFloatNetwork net = (BasicFloatNetwork) obj;
    final FlatNetwork flat = net.getStructure().getFlat();
    out.addSection("BASIC");
    out.addSubSection("PARAMS");
    out.addProperties(net.getProperties());
    out.addSubSection("NETWORK");
    // Flat-network scalar and array fields, in the fixed order the reader expects.
    out.writeProperty(BasicNetwork.TAG_BEGIN_TRAINING, flat.getBeginTraining());
    out.writeProperty(BasicNetwork.TAG_CONNECTION_LIMIT, flat.getConnectionLimit());
    out.writeProperty(BasicNetwork.TAG_CONTEXT_TARGET_OFFSET, flat.getContextTargetOffset());
    out.writeProperty(BasicNetwork.TAG_CONTEXT_TARGET_SIZE, flat.getContextTargetSize());
    out.writeProperty(BasicNetwork.TAG_END_TRAINING, flat.getEndTraining());
    out.writeProperty(BasicNetwork.TAG_HAS_CONTEXT, flat.getHasContext());
    out.writeProperty(PersistConst.INPUT_COUNT, flat.getInputCount());
    out.writeProperty(BasicNetwork.TAG_LAYER_COUNTS, flat.getLayerCounts());
    out.writeProperty(BasicNetwork.TAG_LAYER_FEED_COUNTS, flat.getLayerFeedCounts());
    out.writeProperty(BasicNetwork.TAG_LAYER_CONTEXT_COUNT, flat.getLayerContextCount());
    out.writeProperty(BasicNetwork.TAG_LAYER_INDEX, flat.getLayerIndex());
    out.writeProperty(PersistConst.OUTPUT, flat.getLayerOutput());
    out.writeProperty(PersistConst.OUTPUT_COUNT, flat.getOutputCount());
    out.writeProperty(BasicNetwork.TAG_WEIGHT_INDEX, flat.getWeightIndex());
    out.writeProperty(PersistConst.WEIGHTS, flat.getWeights());
    out.writeProperty(BasicNetwork.TAG_BIAS_ACTIVATION, flat.getBiasActivation());
    out.addSubSection("ACTIVATION");
    // One line per layer: activation simple class name followed by its parameters.
    for (final ActivationFunction af : flat.getActivationFunctions()) {
        out.addColumn(af.getClass().getSimpleName());
        // Hoisted so the params array is not re-fetched on every loop iteration.
        final double[] params = af.getParams();
        for (int i = 0; i < params.length; i++) {
            out.addColumn(params[i]);
        }
        out.writeLine();
    }
    out.addSubSection("SUBSET");
    final Set<Integer> featureList = net.getFeatureSet();
    if (featureList == null || featureList.isEmpty()) {
        // Empty value marks "no feature subset"; the reader maps it back to null.
        out.writeProperty("SUBSETFEATURES", "");
    } else {
        out.writeProperty("SUBSETFEATURES", StringUtils.join(featureList, ","));
    }
    out.flush();
}
Use of org.encog.neural.flat.FlatNetwork in the ShifuML/shifu project.
Class PersistBasicFloatNetwork, method read().
/**
 * Reads a {@link BasicFloatNetwork} from the given stream in Encog's EG text format.
 *
 * <p>Counterpart of {@code save(OutputStream, Object)}: consumes the BASIC section's
 * PARAMS, NETWORK, ACTIVATION and SUBSET sub-sections. Activation classes are resolved
 * by simple name; a handful of shifu-provided activations map to
 * {@code ml.shifu.shifu.core.dtrain.nn}, everything else resolves against Encog's
 * standard activation package.
 *
 * @param is source stream; not closed by this method
 * @return the reconstructed {@link BasicFloatNetwork}
 * @throws PersistError if an activation class cannot be loaded or instantiated
 *
 * {@inheritDoc}
 */
@Override
public final Object read(final InputStream is) {
    final BasicFloatNetwork result = new BasicFloatNetwork();
    final FlatNetwork flat = new FlatNetwork();
    final EncogReadHelper in = new EncogReadHelper(is);
    EncogFileSection section;
    while ((section = in.readNextSection()) != null) {
        if (section.getSectionName().equals("BASIC") && section.getSubSectionName().equals("PARAMS")) {
            final Map<String, String> params = section.parseParams();
            result.getProperties().putAll(params);
        }
        if (section.getSectionName().equals("BASIC") && section.getSubSectionName().equals("NETWORK")) {
            // Scalar and array fields, in the same fixed order save() wrote them.
            final Map<String, String> params = section.parseParams();
            flat.setBeginTraining(EncogFileSection.parseInt(params, BasicNetwork.TAG_BEGIN_TRAINING));
            flat.setConnectionLimit(EncogFileSection.parseDouble(params, BasicNetwork.TAG_CONNECTION_LIMIT));
            flat.setContextTargetOffset(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_CONTEXT_TARGET_OFFSET));
            flat.setContextTargetSize(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_CONTEXT_TARGET_SIZE));
            flat.setEndTraining(EncogFileSection.parseInt(params, BasicNetwork.TAG_END_TRAINING));
            flat.setHasContext(EncogFileSection.parseBoolean(params, BasicNetwork.TAG_HAS_CONTEXT));
            flat.setInputCount(EncogFileSection.parseInt(params, PersistConst.INPUT_COUNT));
            flat.setLayerCounts(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_COUNTS));
            flat.setLayerFeedCounts(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_FEED_COUNTS));
            flat.setLayerContextCount(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_CONTEXT_COUNT));
            flat.setLayerIndex(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_INDEX));
            flat.setLayerOutput(EncogFileSection.parseDoubleArray(params, PersistConst.OUTPUT));
            // Layer sums are runtime scratch space, not persisted; size matches layer output.
            flat.setLayerSums(new double[flat.getLayerOutput().length]);
            flat.setOutputCount(EncogFileSection.parseInt(params, PersistConst.OUTPUT_COUNT));
            flat.setWeightIndex(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_WEIGHT_INDEX));
            flat.setWeights(EncogFileSection.parseDoubleArray(params, PersistConst.WEIGHTS));
            flat.setBiasActivation(EncogFileSection.parseDoubleArray(params, BasicNetwork.TAG_BIAS_ACTIVATION));
        } else if (section.getSectionName().equals("BASIC") && section.getSubSectionName().equals("ACTIVATION")) {
            int index = 0;
            flat.setActivationFunctions(new ActivationFunction[flat.getLayerCounts().length]);
            for (final String line : section.getLines()) {
                final List<String> cols = EncogFileSection.splitColumns(line);
                final String simpleName = cols.get(0);
                // shifu ships its own implementations of these activations; all other
                // names resolve against Encog's standard activation package.
                String name = "org.encog.engine.network.activation." + simpleName;
                if (simpleName.equals("ActivationReLU")) {
                    name = "ml.shifu.shifu.core.dtrain.nn.ActivationReLU";
                } else if (simpleName.equals("ActivationLeakyReLU")) {
                    name = "ml.shifu.shifu.core.dtrain.nn.ActivationLeakyReLU";
                } else if (simpleName.equals("ActivationSwish")) {
                    name = "ml.shifu.shifu.core.dtrain.nn.ActivationSwish";
                } else if (simpleName.equals("ActivationPTANH")) {
                    name = "ml.shifu.shifu.core.dtrain.nn.ActivationPTANH";
                }
                ActivationFunction af = null;
                try {
                    final Class<?> clazz = Class.forName(name);
                    af = (ActivationFunction) clazz.newInstance();
                } catch (final ClassNotFoundException | InstantiationException | IllegalAccessException e) {
                    // Wrap every reflective failure the same way the persistence API expects.
                    throw new PersistError(e);
                }
                // Remaining columns are the activation's parameter values, in order.
                for (int i = 0; i < af.getParamNames().length; i++) {
                    af.setParam(i, CSVFormat.EG_FORMAT.parse(cols.get(i + 1)));
                }
                flat.getActivationFunctions()[index++] = af;
            }
        } else if (section.getSectionName().equals("BASIC") && section.getSubSectionName().equals("SUBSET")) {
            final Map<String, String> params = section.parseParams();
            final String subsetStr = params.get("SUBSETFEATURES");
            if (StringUtils.isBlank(subsetStr)) {
                // Blank value means no feature subset was persisted.
                result.setFeatureSet(null);
            } else {
                final String[] splits = subsetStr.split(",");
                final Set<Integer> subFeatures = new HashSet<Integer>();
                for (final String split : splits) {
                    subFeatures.add(Integer.parseInt(split));
                }
                result.setFeatureSet(subFeatures);
            }
        }
    }
    result.getStructure().setFlat(flat);
    return result;
}
Use of org.encog.neural.flat.FlatNetwork in the ShifuML/shifu project.
Class PersistBasicFloatNetwork, method readNetwork().
/**
 * Reads a {@link BasicFloatNetwork} from a binary {@link DataInput} stream.
 *
 * <p>Binary counterpart of {@code read(InputStream)}, used for compact model exchange
 * (e.g. between distributed-training workers). The field order must match the writer
 * exactly: properties map, flat-network scalars/arrays, activation functions, then the
 * feature-subset indexes.
 *
 * @param in source of the serialized network
 * @return the reconstructed {@link BasicFloatNetwork}
 * @throws IOException   if reading from {@code in} fails
 * @throws PersistError  if an activation class cannot be loaded or instantiated
 */
public BasicFloatNetwork readNetwork(final DataInput in) throws IOException {
    final BasicFloatNetwork result = new BasicFloatNetwork();
    final FlatNetwork flat = new FlatNetwork();
    // read properties: count followed by (key, value) string pairs
    Map<String, String> properties = new HashMap<String, String>();
    int size = in.readInt();
    for (int i = 0; i < size; i++) {
        properties.put(ml.shifu.shifu.core.dtrain.StringUtils.readString(in), ml.shifu.shifu.core.dtrain.StringUtils.readString(in));
    }
    result.getProperties().putAll(properties);
    // read fields, in the fixed order the writer emitted them
    flat.setBeginTraining(in.readInt());
    flat.setConnectionLimit(in.readDouble());
    flat.setContextTargetOffset(readIntArray(in));
    flat.setContextTargetSize(readIntArray(in));
    flat.setEndTraining(in.readInt());
    flat.setHasContext(in.readBoolean());
    flat.setInputCount(in.readInt());
    flat.setLayerCounts(readIntArray(in));
    flat.setLayerFeedCounts(readIntArray(in));
    flat.setLayerContextCount(readIntArray(in));
    flat.setLayerIndex(readIntArray(in));
    flat.setLayerOutput(readDoubleArray(in));
    flat.setOutputCount(in.readInt());
    // Layer sums are runtime scratch space, not serialized; size matches layer output.
    flat.setLayerSums(new double[flat.getLayerOutput().length]);
    flat.setWeightIndex(readIntArray(in));
    flat.setWeights(readDoubleArray(in));
    flat.setBiasActivation(readDoubleArray(in));
    // read activations: count, then (simple class name, params array) per function
    flat.setActivationFunctions(new ActivationFunction[flat.getLayerCounts().length]);
    int acSize = in.readInt();
    for (int i = 0; i < acSize; i++) {
        String name = ml.shifu.shifu.core.dtrain.StringUtils.readString(in);
        // shifu ships its own implementations of these activations; all other names
        // resolve against Encog's standard activation package.
        if (name.equals("ActivationReLU")) {
            name = ActivationReLU.class.getName();
        } else if (name.equals("ActivationLeakyReLU")) {
            name = ActivationLeakyReLU.class.getName();
        } else if (name.equals("ActivationSwish")) {
            name = ActivationSwish.class.getName();
        } else if (name.equals("ActivationPTANH")) {
            name = ActivationPTANH.class.getName();
        } else {
            name = "org.encog.engine.network.activation." + name;
        }
        ActivationFunction af = null;
        try {
            final Class<?> clazz = Class.forName(name);
            af = (ActivationFunction) clazz.newInstance();
        } catch (final ClassNotFoundException | InstantiationException | IllegalAccessException e) {
            // Wrap every reflective failure the same way the persistence API expects.
            throw new PersistError(e);
        }
        double[] params = readDoubleArray(in);
        for (int j = 0; j < params.length; j++) {
            af.setParam(j, params[j]);
        }
        flat.getActivationFunctions()[i] = af;
    }
    // read subset: count followed by that many feature indexes
    int subsetSize = in.readInt();
    Set<Integer> featureList = new HashSet<Integer>();
    for (int i = 0; i < subsetSize; i++) {
        featureList.add(in.readInt());
    }
    result.setFeatureSet(featureList);
    result.getStructure().setFlat(flat);
    return result;
}
Use of org.encog.neural.flat.FlatNetwork in the ShifuML/shifu project.
Class NNModelSpecTest, method testModelStructureCompare().
/**
 * Checks NNStructureComparator's ordering over flat networks loaded from the
 * test model specs: a base structure sorts before an extended one (-1), equal
 * structures compare as 0, and structurally different networks order as expected.
 */
@Test
public void testModelStructureCompare() {
    final NNStructureComparator comparator = new NNStructureComparator();
    // Base model and an extended variant of the same structure.
    final BasicNetwork baseNetwork = (BasicNetwork) EncogDirectoryPersistence
            .loadObject(new File("src/test/resources/model/model0.nn"));
    final FlatNetwork baseFlat = baseNetwork.getFlat();
    final BasicNetwork extendedNetwork = (BasicNetwork) EncogDirectoryPersistence
            .loadObject(new File("src/test/resources/model/model1.nn"));
    final FlatNetwork extendedFlat = extendedNetwork.getFlat();
    Assert.assertEquals(comparator.compare(baseFlat, extendedFlat), -1);
    Assert.assertEquals(comparator.compare(baseFlat, baseFlat), 0);
    Assert.assertEquals(comparator.compare(extendedFlat, baseFlat), 1);
    // A structurally different model compares below both base and extended.
    final BasicNetwork diffNetwork = (BasicNetwork) EncogDirectoryPersistence
            .loadObject(new File("src/test/resources/model/model2.nn"));
    final FlatNetwork diffFlat = diffNetwork.getFlat();
    Assert.assertEquals(comparator.compare(baseFlat, diffFlat), -1);
    Assert.assertEquals(comparator.compare(diffFlat, baseFlat), -1);
    Assert.assertEquals(comparator.compare(extendedFlat, diffFlat), 1);
    Assert.assertEquals(comparator.compare(diffFlat, extendedFlat), -1);
    // A deeper model sorts above the base structure.
    final BasicNetwork deepNetwork = (BasicNetwork) EncogDirectoryPersistence
            .loadObject(new File("src/test/resources/model/model3.nn"));
    final FlatNetwork deepFlat = deepNetwork.getFlat();
    Assert.assertEquals(comparator.compare(deepFlat, baseFlat), 1);
    Assert.assertEquals(comparator.compare(baseFlat, deepFlat), -1);
}
Use of org.encog.neural.flat.FlatNetwork in the ShifuML/shifu project.
Class NNModelSpecTest, method testFitExistingModelIn().
/**
 * Checks NNMaster.fitExistingModelIn: fitting a model into itself (or into an
 * extended structure) yields the expected number of fixed weight indexes for a
 * given set of fixed layers, and disabling the bias-fitting flag shrinks that set.
 */
@Test
public void testFitExistingModelIn() {
    final BasicNetwork baseNetwork = (BasicNetwork) EncogDirectoryPersistence
            .loadObject(new File("src/test/resources/model/model0.nn"));
    final FlatNetwork baseFlat = baseNetwork.getFlat();
    final NNMaster master = new NNMaster();
    // Fit a model into itself, fixing layer 6: 31 weights stay fixed.
    Set<Integer> fixedWeights = master.fitExistingModelIn(baseFlat, baseFlat, Arrays.asList(6));
    List<Integer> sortedIndexes = new ArrayList<Integer>(fixedWeights);
    Collections.sort(sortedIndexes);
    Assert.assertEquals(sortedIndexes.size(), 31);
    // Fixing layer 1 instead fixes 930 weights.
    fixedWeights = master.fitExistingModelIn(baseFlat, baseFlat, Arrays.asList(1));
    sortedIndexes = new ArrayList<Integer>(fixedWeights);
    Collections.sort(sortedIndexes);
    Assert.assertEquals(sortedIndexes.size(), 930);
    // Fitting into an extended structure keeps the same fixed-weight count...
    final BasicNetwork extendedNetwork = (BasicNetwork) EncogDirectoryPersistence
            .loadObject(new File("src/test/resources/model/model1.nn"));
    final FlatNetwork extendedFlat = extendedNetwork.getFlat();
    fixedWeights = master.fitExistingModelIn(baseFlat, extendedFlat, Arrays.asList(1));
    sortedIndexes = new ArrayList<Integer>(fixedWeights);
    Collections.sort(sortedIndexes);
    Assert.assertEquals(sortedIndexes.size(), 930);
    // ...while passing false for the bias flag drops the 30 bias weights.
    fixedWeights = master.fitExistingModelIn(baseFlat, extendedFlat, Arrays.asList(1), false);
    sortedIndexes = new ArrayList<Integer>(fixedWeights);
    Collections.sort(sortedIndexes);
    Assert.assertEquals(sortedIndexes.size(), 900);
}
Aggregations