Use of org.encog.neural.flat.FlatNetwork in project shifu by ShifuML.
The class PersistBasicFloatNetwork, method saveNetwork.
public void saveNetwork(DataOutput out, final BasicFloatNetwork network) throws IOException {
    final FlatNetwork flat = network.getStructure().getFlat();
    // write general properties
    Map<String, String> properties = network.getProperties();
    if (properties == null) {
        out.writeInt(0);
    } else {
        out.writeInt(properties.size());
        for (Entry<String, String> entry : properties.entrySet()) {
            ml.shifu.shifu.core.dtrain.StringUtils.writeString(out, entry.getKey());
            ml.shifu.shifu.core.dtrain.StringUtils.writeString(out, entry.getValue());
        }
    }
    // write field values in BasicFloatNetwork
    out.writeInt(flat.getBeginTraining());
    out.writeDouble(flat.getConnectionLimit());
    writeIntArray(out, flat.getContextTargetOffset());
    writeIntArray(out, flat.getContextTargetSize());
    out.writeInt(flat.getEndTraining());
    out.writeBoolean(flat.getHasContext());
    out.writeInt(flat.getInputCount());
    writeIntArray(out, flat.getLayerCounts());
    writeIntArray(out, flat.getLayerFeedCounts());
    writeIntArray(out, flat.getLayerContextCount());
    writeIntArray(out, flat.getLayerIndex());
    writeDoubleArray(out, flat.getLayerOutput());
    out.writeInt(flat.getOutputCount());
    writeIntArray(out, flat.getWeightIndex());
    writeDoubleArray(out, flat.getWeights());
    writeDoubleArray(out, flat.getBiasActivation());
    // write the activation list
    out.writeInt(flat.getActivationFunctions().length);
    for (final ActivationFunction af : flat.getActivationFunctions()) {
        ml.shifu.shifu.core.dtrain.StringUtils.writeString(out, af.getClass().getSimpleName());
        writeDoubleArray(out, af.getParams());
    }
    // write the feature subset
    Set<Integer> featureList = network.getFeatureSet();
    if (featureList == null || featureList.size() == 0) {
        out.writeInt(0);
    } else {
        out.writeInt(featureList.size());
        for (Integer integer : featureList) {
            out.writeInt(integer);
        }
    }
}
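Deserialization has to consume these fields in exactly the same order they were written. Below is a minimal sketch of the matching read path; it assumes FlatNetwork exposes setters symmetric to the getters used above, and that readIntArray/readDoubleArray and StringUtils.readString (hypothetical names here) mirror the write-side helpers.

public BasicFloatNetwork readNetwork(DataInput in) throws IOException {
    // read general properties: a count followed by key/value string pairs
    Map<String, String> properties = new HashMap<String, String>();
    int propertyCount = in.readInt();
    for (int i = 0; i < propertyCount; i++) {
        String key = ml.shifu.shifu.core.dtrain.StringUtils.readString(in);   // assumed mirror of writeString
        String value = ml.shifu.shifu.core.dtrain.StringUtils.readString(in);
        properties.put(key, value);
    }
    // rebuild the FlatNetwork field by field, in the exact write order
    FlatNetwork flat = new FlatNetwork();
    flat.setBeginTraining(in.readInt());
    flat.setConnectionLimit(in.readDouble());
    flat.setContextTargetOffset(readIntArray(in));   // readIntArray: hypothetical mirror of writeIntArray
    flat.setContextTargetSize(readIntArray(in));
    flat.setEndTraining(in.readInt());
    flat.setHasContext(in.readBoolean());
    flat.setInputCount(in.readInt());
    flat.setLayerCounts(readIntArray(in));
    flat.setLayerFeedCounts(readIntArray(in));
    flat.setLayerContextCount(readIntArray(in));
    flat.setLayerIndex(readIntArray(in));
    flat.setLayerOutput(readDoubleArray(in));        // readDoubleArray: hypothetical mirror of writeDoubleArray
    flat.setOutputCount(in.readInt());
    flat.setWeightIndex(readIntArray(in));
    flat.setWeights(readDoubleArray(in));
    flat.setBiasActivation(readDoubleArray(in));
    // the activation list and the feature subset are read back analogously (omitted)
    BasicFloatNetwork result = new BasicFloatNetwork();
    result.getStructure().setFlat(flat);
    result.getProperties().putAll(properties);
    return result;
}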
Use of org.encog.neural.flat.FlatNetwork in project shifu by ShifuML.
The class AbstractNNWorker, method initGradient.
@SuppressWarnings("unchecked")
private void initGradient(FloatMLDataSet training, FloatMLDataSet testing, double[] weights, boolean isCrossOver) {
    int numLayers = (Integer) this.validParams.get(CommonConstants.NUM_HIDDEN_LAYERS);
    List<String> actFunc = (List<String>) this.validParams.get(CommonConstants.ACTIVATION_FUNC);
    List<Integer> hiddenNodeList = (List<Integer>) this.validParams.get(CommonConstants.NUM_HIDDEN_NODES);
    String outputActivationFunc = (String) validParams.get(CommonConstants.OUTPUT_ACTIVATION_FUNC);
    BasicNetwork network = DTrainUtils.generateNetwork(this.featureInputsCnt, this.outputNodeCount, numLayers,
            actFunc, hiddenNodeList, false, this.dropoutRate, this.wgtInit,
            CommonUtils.isLinearTarget(modelConfig, columnConfigList), outputActivationFunc);
    // use the weights received from the master
    network.getFlat().setWeights(weights);
    FlatNetwork flat = network.getFlat();
    // copied from Encog's Propagation: fix the sigmoid flat-spot problem
    double[] flatSpot = new double[flat.getActivationFunctions().length];
    for (int i = 0; i < flat.getActivationFunctions().length; i++) {
        flatSpot[i] = flat.getActivationFunctions()[i] instanceof ActivationSigmoid ? 0.1 : 0.0;
    }
    LOG.info("Gradient computing thread count is {}.", modelConfig.getTrain().getWorkerThreadCount());
    this.gradient = new ParallelGradient((FloatFlatNetwork) flat, training, testing, flatSpot,
            new LinearErrorFunction(), isCrossOver, modelConfig.getTrain().getWorkerThreadCount(),
            this.lossStr, this.batchs);
}
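The flatSpot array implements the classic flat-spot fix: a sigmoid's derivative approaches zero as the unit saturates, so a small constant (0.1 for sigmoid layers, 0.0 otherwise) is added to the derivative during backpropagation to keep the gradient signal alive. A minimal sketch of where such a constant enters the delta computation (illustrative, not the ParallelGradient internals):

// sketch: how a flat-spot constant enters the backprop delta computation
double adjustedDerivative(ActivationFunction af, double before, double after, double flatSpotConstant) {
    // Encog's derivativeFunction(b, a) receives the pre-activation value b and
    // the activation output a; adding the constant keeps saturated sigmoid
    // units (whose true derivative is near zero) receiving a gradient signal
    return af.derivativeFunction(before, after) + flatSpotConstant;
}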
Use of org.encog.neural.flat.FlatNetwork in project shifu by ShifuML.
The class DTrainTest, method initGradient.
public Gradient initGradient(MLDataSet training) {
    FlatNetwork flat = network.getFlat().clone();
    // copied from Encog's Propagation: sigmoid layers get a 0.1 flat-spot constant
    double[] flatSpot = new double[flat.getActivationFunctions().length];
    for (int i = 0; i < flat.getActivationFunctions().length; i++) {
        final ActivationFunction af = flat.getActivationFunctions()[i];
        if (af instanceof ActivationSigmoid) {
            flatSpot[i] = 0.1;
        } else {
            flatSpot[i] = 0.0;
        }
    }
    return new Gradient(flat, training.openAdditional(), training, flatSpot, new LinearErrorFunction(), false);
}
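The clone() on the first line matters: FlatNetwork.compute() writes into the network's internal layerOutput buffer, so the Gradient gets a private copy rather than the live network. A small sketch of the independence this buys, assuming (as in Encog's implementation) that clone() deep-copies the underlying arrays:

FlatNetwork original = network.getFlat();
FlatNetwork copy = original.clone();
// the clone owns separate weight and layerOutput buffers, so the gradient
// computation can mutate its copy without touching the original network
Assert.assertNotSame(original.getWeights(), copy.getWeights());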
Use of org.encog.neural.flat.FlatNetwork in project shifu by ShifuML.
The class NNModelSpecTest, method testModelTraverse.
// @Test
public void testModelTraverse() {
    BasicML basicML = BasicML.class.cast(
            EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model0.nn")));
    BasicNetwork basicNetwork = (BasicNetwork) basicML;
    FlatNetwork flatNetwork = basicNetwork.getFlat();
    BasicML extendedBasicML = BasicML.class.cast(
            EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model1.nn")));
    BasicNetwork extendedBasicNetwork = (BasicNetwork) extendedBasicML;
    FlatNetwork extendedFlatNetwork = extendedBasicNetwork.getFlat();
    for (int layer = flatNetwork.getLayerIndex().length - 1; layer > 0; layer--) {
        int layerOutputCnt = flatNetwork.getLayerFeedCounts()[layer - 1];
        int layerInputCnt = flatNetwork.getLayerCounts()[layer];
        System.out.println("Weight index for layer " + (flatNetwork.getLayerIndex().length - layer));
        int extendedLayerInputCnt = extendedFlatNetwork.getLayerCounts()[layer];
        int indexPos = flatNetwork.getWeightIndex()[layer - 1];
        int extendedIndexPos = extendedFlatNetwork.getWeightIndex()[layer - 1];
        for (int i = 0; i < layerOutputCnt; i++) {
            for (int j = 0; j < layerInputCnt; j++) {
                int weightIndex = indexPos + (i * layerInputCnt) + j;
                int extendedWeightIndex = extendedIndexPos + (i * extendedLayerInputCnt) + j;
                if (j == layerInputCnt - 1) {
                    // move bias to the end of the extended row
                    extendedWeightIndex = extendedIndexPos + (i * extendedLayerInputCnt) + (extendedLayerInputCnt - 1);
                }
                System.out.println(weightIndex + " --> " + extendedWeightIndex);
            }
        }
    }
}
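The index arithmetic above follows Encog's flat weight layout: layer 0 is the output layer, getWeightIndex()[l] marks the first weight feeding layer l, and within that block the weight from source neuron j (with the bias occupying the last slot) to target neuron i sits at offset i * layerCounts[l + 1] + j. A hypothetical helper capturing the formula the nested loops expand:

// hypothetical helper: flat index of the weight from source neuron j in
// layer (targetLayer + 1) to neuron i in targetLayer; in Encog's layout the
// bias weight of each target neuron is the last j slot of its row
static int flatWeightIndex(FlatNetwork net, int targetLayer, int i, int j) {
    int sourceCount = net.getLayerCounts()[targetLayer + 1]; // includes the bias unit
    return net.getWeightIndex()[targetLayer] + i * sourceCount + j;
}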
Use of org.encog.neural.flat.FlatNetwork in project shifu by ShifuML.
The class NNModelSpecTest, method testModelFitIn.
@Test
public void testModelFitIn() {
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());
    BasicML basicML = BasicML.class.cast(
            EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model5.nn")));
    BasicNetwork basicNetwork = (BasicNetwork) basicML;
    FlatNetwork flatNetwork = basicNetwork.getFlat();
    BasicML extendedBasicML = BasicML.class.cast(
            EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model6.nn")));
    BasicNetwork extendedBasicNetwork = (BasicNetwork) extendedBasicML;
    FlatNetwork extendedFlatNetwork = extendedBasicNetwork.getFlat();
    NNMaster master = new NNMaster();
    Set<Integer> fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, extendedFlatNetwork,
            Arrays.asList(new Integer[] { 1, 2, 3 }));
    Assert.assertEquals(fixedWeightIndexSet.size(), 931);
    fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, extendedFlatNetwork,
            Arrays.asList(new Integer[] { 1, 2, 3 }), false);
    Assert.assertEquals(fixedWeightIndexSet.size(), 910);
}
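fitExistingModelIn maps the smaller model's weights into the extended network and returns the flat indices of the carried-over weights, to be held fixed during continued training. A sketch of how such a set might be honored in an update loop (illustrative only; weights, gradients, and learningRate are hypothetical locals, not NNMaster's API):

// skip weights carried over from the existing model; only new weights learn
for (int w = 0; w < weights.length; w++) {
    if (!fixedWeightIndexSet.contains(w)) {
        weights[w] -= learningRate * gradients[w];
    }
}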