
Example 1 with FlatNetwork

Use of org.encog.neural.flat.FlatNetwork in project shifu by ShifuML.

The class PersistBasicFloatNetwork, method saveNetwork.

public void saveNetwork(DataOutput out, final BasicFloatNetwork network) throws IOException {
    final FlatNetwork flat = network.getStructure().getFlat();
    // write general properties
    Map<String, String> properties = network.getProperties();
    if (properties == null) {
        out.writeInt(0);
    } else {
        out.writeInt(properties.size());
        for (Entry<String, String> entry : properties.entrySet()) {
            ml.shifu.shifu.core.dtrain.StringUtils.writeString(out, entry.getKey());
            ml.shifu.shifu.core.dtrain.StringUtils.writeString(out, entry.getValue());
        }
    }
    // write fields values in BasicFloatNetwork
    out.writeInt(flat.getBeginTraining());
    out.writeDouble(flat.getConnectionLimit());
    writeIntArray(out, flat.getContextTargetOffset());
    writeIntArray(out, flat.getContextTargetSize());
    out.writeInt(flat.getEndTraining());
    out.writeBoolean(flat.getHasContext());
    out.writeInt(flat.getInputCount());
    writeIntArray(out, flat.getLayerCounts());
    writeIntArray(out, flat.getLayerFeedCounts());
    writeIntArray(out, flat.getLayerContextCount());
    writeIntArray(out, flat.getLayerIndex());
    writeDoubleArray(out, flat.getLayerOutput());
    out.writeInt(flat.getOutputCount());
    writeIntArray(out, flat.getWeightIndex());
    writeDoubleArray(out, flat.getWeights());
    writeDoubleArray(out, flat.getBiasActivation());
    // write activation list
    out.writeInt(flat.getActivationFunctions().length);
    for (final ActivationFunction af : flat.getActivationFunctions()) {
        ml.shifu.shifu.core.dtrain.StringUtils.writeString(out, af.getClass().getSimpleName());
        writeDoubleArray(out, af.getParams());
    }
    // write sub sets
    Set<Integer> featureList = network.getFeatureSet();
    if (featureList == null || featureList.size() == 0) {
        out.writeInt(0);
    } else {
        out.writeInt(featureList.size());
        for (Integer integer : featureList) {
            out.writeInt(integer);
        }
    }
}
Also used: FlatNetwork (org.encog.neural.flat.FlatNetwork), ActivationFunction (org.encog.engine.network.activation.ActivationFunction)
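
The writeIntArray and writeDoubleArray helpers are referenced above but not shown. A minimal sketch of what they might look like, assuming a simple length-prefixed layout that a matching read routine would mirror (the exact wire format used by shifu is not reproduced here):

// Hypothetical helpers, assuming a length-prefixed encoding.
private void writeIntArray(DataOutput out, int[] array) throws IOException {
    out.writeInt(array.length);
    for (int value : array) {
        out.writeInt(value);
    }
}

private void writeDoubleArray(DataOutput out, double[] array) throws IOException {
    out.writeInt(array.length);
    for (double value : array) {
        out.writeDouble(value);
    }
}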

Example 2 with FlatNetwork

Use of org.encog.neural.flat.FlatNetwork in project shifu by ShifuML.

The class AbstractNNWorker, method initGradient.

@SuppressWarnings("unchecked")
private void initGradient(FloatMLDataSet training, FloatMLDataSet testing, double[] weights, boolean isCrossOver) {
    int numLayers = (Integer) this.validParams.get(CommonConstants.NUM_HIDDEN_LAYERS);
    List<String> actFunc = (List<String>) this.validParams.get(CommonConstants.ACTIVATION_FUNC);
    List<Integer> hiddenNodeList = (List<Integer>) this.validParams.get(CommonConstants.NUM_HIDDEN_NODES);
    String outputActivationFunc = (String) validParams.get(CommonConstants.OUTPUT_ACTIVATION_FUNC);
    BasicNetwork network = DTrainUtils.generateNetwork(this.featureInputsCnt, this.outputNodeCount, numLayers, actFunc, hiddenNodeList, false, this.dropoutRate, this.wgtInit, CommonUtils.isLinearTarget(modelConfig, columnConfigList), outputActivationFunc);
    // use the weights from master
    network.getFlat().setWeights(weights);
    FlatNetwork flat = network.getFlat();
    // copy Propagation from encog, fix flat spot problem
    double[] flatSpot = new double[flat.getActivationFunctions().length];
    for (int i = 0; i < flat.getActivationFunctions().length; i++) {
        flatSpot[i] = flat.getActivationFunctions()[i] instanceof ActivationSigmoid ? 0.1 : 0.0;
    }
    LOG.info("Gradient computing thread count is {}.", modelConfig.getTrain().getWorkerThreadCount());
    this.gradient = new ParallelGradient((FloatFlatNetwork) flat, training, testing, flatSpot, new LinearErrorFunction(), isCrossOver, modelConfig.getTrain().getWorkerThreadCount(), this.lossStr, this.batchs);
}
Also used: LinearErrorFunction (org.encog.neural.error.LinearErrorFunction), FloatFlatNetwork (ml.shifu.shifu.core.dtrain.dataset.FloatFlatNetwork), FlatNetwork (org.encog.neural.flat.FlatNetwork), BasicNetwork (org.encog.neural.networks.BasicNetwork), ActivationSigmoid (org.encog.engine.network.activation.ActivationSigmoid), ArrayList (java.util.ArrayList), List (java.util.List)
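
The flatSpot array is Fahlman's classic fix for the sigmoid "flat spot": the sigmoid derivative s * (1 - s) collapses toward zero once a neuron saturates near 0 or 1, which stalls backpropagation, so a small constant is added to the derivative for sigmoid layers. A standalone illustration of the arithmetic (Encog applies the offset internally during gradient computation; the 6.0 input is just an example of a saturated pre-activation):

// Why the 0.1 flat-spot constant matters for saturated sigmoids.
double sum = 6.0; // strongly saturated pre-activation
double s = 1.0 / (1.0 + Math.exp(-sum)); // sigmoid output, about 0.9975
double rawDerivative = s * (1.0 - s); // about 0.0025, nearly zero
double withFlatSpot = rawDerivative + 0.1; // the offset keeps weight updates alive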

Example 3 with FlatNetwork

Use of org.encog.neural.flat.FlatNetwork in project shifu by ShifuML.

The class DTrainTest, method initGradient.

public Gradient initGradient(MLDataSet training) {
    FlatNetwork flat = network.getFlat().clone();
    // copy Propagation from encog
    double[] flatSpot = new double[flat.getActivationFunctions().length];
    for (int i = 0; i < flat.getActivationFunctions().length; i++) {
        final ActivationFunction af = flat.getActivationFunctions()[i];
        if (af instanceof ActivationSigmoid) {
            flatSpot[i] = 0.1;
        } else {
            flatSpot[i] = 0.0;
        }
    }
    return new Gradient(flat, training.openAdditional(), training, flatSpot, new LinearErrorFunction(), false);
}
Also used: LinearErrorFunction (org.encog.neural.error.LinearErrorFunction), FlatNetwork (org.encog.neural.flat.FlatNetwork), ActivationFunction (org.encog.engine.network.activation.ActivationFunction), ActivationSigmoid (org.encog.engine.network.activation.ActivationSigmoid)
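
One way this test helper might be driven, as a hedged sketch: build a tiny in-memory dataset with Encog's BasicMLDataSet (org.encog.ml.data.basic.BasicMLDataSet) and pass it to initGradient. The XOR data, and the assumption that the test's network field has two inputs and one output, are purely illustrative:

// Illustrative caller inside DTrainTest; the dataset shape is an assumption.
double[][] input = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
double[][] ideal = { { 0 }, { 1 }, { 1 }, { 0 } };
MLDataSet training = new BasicMLDataSet(input, ideal);
Gradient gradient = initGradient(training);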

Example 4 with FlatNetwork

Use of org.encog.neural.flat.FlatNetwork in project shifu by ShifuML.

The class NNModelSpecTest, method testModelTraverse.

// @Test
public void testModelTraverse() {
    BasicML basicML = BasicML.class.cast(EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model0.nn")));
    BasicNetwork basicNetwork = (BasicNetwork) basicML;
    FlatNetwork flatNetwork = basicNetwork.getFlat();
    BasicML extendedBasicML = BasicML.class.cast(EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model1.nn")));
    BasicNetwork extendedBasicNetwork = (BasicNetwork) extendedBasicML;
    FlatNetwork extendedFlatNetwork = extendedBasicNetwork.getFlat();
    for (int layer = flatNetwork.getLayerIndex().length - 1; layer > 0; layer--) {
        int layerOutputCnt = flatNetwork.getLayerFeedCounts()[layer - 1];
        int layerInputCnt = flatNetwork.getLayerCounts()[layer];
        System.out.println("Weight index for layer " + (flatNetwork.getLayerIndex().length - layer));
        int extendedLayerInputCnt = extendedFlatNetwork.getLayerCounts()[layer];
        int indexPos = flatNetwork.getWeightIndex()[layer - 1];
        int extendedIndexPos = extendedFlatNetwork.getWeightIndex()[layer - 1];
        for (int i = 0; i < layerOutputCnt; i++) {
            for (int j = 0; j < layerInputCnt; j++) {
                int weightIndex = indexPos + (i * layerInputCnt) + j;
                int extendedWeightIndex = extendedIndexPos + (i * extendedLayerInputCnt) + j;
                if (j == layerInputCnt - 1) {
                    // move bias to end
                    extendedWeightIndex = extendedIndexPos + (i * extendedLayerInputCnt) + (extendedLayerInputCnt - 1);
                }
                System.out.println(weightIndex + " --> " + extendedWeightIndex);
            }
        }
    }
}
Also used: FlatNetwork (org.encog.neural.flat.FlatNetwork), BasicNetwork (org.encog.neural.networks.BasicNetwork), BasicML (org.encog.ml.BasicML), File (java.io.File)
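
The traversal leans on Encog's flat weight layout: layers are stored output-first, getLayerCounts()[l] includes the bias neuron while getLayerFeedCounts()[l] excludes it, and the weights feeding layer l - 1 from layer l start at getWeightIndex()[l - 1], stored row by row per target neuron with the bias weight last in each row. A small hedged helper that reads a single weight by (layer, target, source) coordinates under that layout:

// Hypothetical accessor; follows the same index convention as the loop above.
static double weightAt(FlatNetwork flat, int layer, int targetNeuron, int sourceNeuron) {
    int inputCount = flat.getLayerCounts()[layer]; // includes the bias neuron
    int base = flat.getWeightIndex()[layer - 1]; // first weight feeding layer - 1
    return flat.getWeights()[base + targetNeuron * inputCount + sourceNeuron];
}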

Example 5 with FlatNetwork

Use of org.encog.neural.flat.FlatNetwork in project shifu by ShifuML.

The class NNModelSpecTest, method testModelFitIn.

@Test
public void testModelFitIn() {
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());
    BasicML basicML = BasicML.class.cast(EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model5.nn")));
    BasicNetwork basicNetwork = (BasicNetwork) basicML;
    FlatNetwork flatNetwork = basicNetwork.getFlat();
    BasicML extendedBasicML = BasicML.class.cast(EncogDirectoryPersistence.loadObject(new File("src/test/resources/model/model6.nn")));
    BasicNetwork extendedBasicNetwork = (BasicNetwork) extendedBasicML;
    FlatNetwork extendedFlatNetwork = extendedBasicNetwork.getFlat();
    NNMaster master = new NNMaster();
    Set<Integer> fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, extendedFlatNetwork, Arrays.asList(new Integer[] { 1, 2, 3 }));
    Assert.assertEquals(fixedWeightIndexSet.size(), 931);
    fixedWeightIndexSet = master.fitExistingModelIn(flatNetwork, extendedFlatNetwork, Arrays.asList(new Integer[] { 1, 2, 3 }), false);
    Assert.assertEquals(fixedWeightIndexSet.size(), 910);
}
Also used: FlatNetwork (org.encog.neural.flat.FlatNetwork), NNMaster (ml.shifu.shifu.core.dtrain.nn.NNMaster), BasicNetwork (org.encog.neural.networks.BasicNetwork), BasicML (org.encog.ml.BasicML), File (java.io.File), PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork), Test (org.testng.annotations.Test)
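
fitExistingModelIn copies the smaller model's weights into the matching positions of the extended network and returns the indices of those weights so they can be held fixed while the new parts train; the boolean argument in the second call shrinks the fixed set (910 vs. 931), presumably by releasing some of the fitted weights. A hedged sketch of how such a set might be consumed in a plain gradient-descent update (not shifu's actual update code):

// Hypothetical update step that freezes the fitted weights.
static void applyGradients(double[] weights, double[] gradients, double learningRate,
        Set<Integer> fixedWeightIndexSet) {
    for (int i = 0; i < weights.length; i++) {
        if (fixedWeightIndexSet.contains(i)) {
            continue; // weight was fitted from the existing model; keep it fixed
        }
        weights[i] -= learningRate * gradients[i];
    }
}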

Aggregations

FlatNetwork (org.encog.neural.flat.FlatNetwork): 10
ActivationFunction (org.encog.engine.network.activation.ActivationFunction): 5
BasicNetwork (org.encog.neural.networks.BasicNetwork): 5
File (java.io.File): 4
BasicML (org.encog.ml.BasicML): 4
Test (org.testng.annotations.Test): 3
ArrayList (java.util.ArrayList): 2
NNMaster (ml.shifu.shifu.core.dtrain.nn.NNMaster): 2
ActivationSigmoid (org.encog.engine.network.activation.ActivationSigmoid): 2
LinearErrorFunction (org.encog.neural.error.LinearErrorFunction): 2
List (java.util.List): 1
FloatFlatNetwork (ml.shifu.shifu.core.dtrain.dataset.FloatFlatNetwork): 1
PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork): 1
ActivationReLU (ml.shifu.shifu.core.dtrain.nn.ActivationReLU): 1
ActivationSwish (ml.shifu.shifu.core.dtrain.nn.ActivationSwish): 1
NNStructureComparator (ml.shifu.shifu.core.dtrain.nn.NNStructureComparator): 1