Search in sources :

Example 6 with FlatNetwork

use of org.encog.neural.flat.FlatNetwork in project shifu by ShifuML.

the class PersistBasicFloatNetwork method save.

/**
 * {@inheritDoc}
 *
 * <p>Serializes a {@code BasicFloatNetwork} into a single "BASIC" section made of
 * four sub-sections, written in this order: "PARAMS" (network properties),
 * "NETWORK" (the flat network fields), "ACTIVATION" (one line per activation
 * function: simple class name followed by its parameters) and "SUBSET" (the
 * comma-joined feature subset, or an empty value when there is none).
 *
 * @param os  stream the network is written to
 * @param obj object to persist; cast to {@code BasicFloatNetwork}
 */
@Override
public final void save(final OutputStream os, final Object obj) {
    final EncogWriteHelper writer = new EncogWriteHelper(os);
    final BasicFloatNetwork network = (BasicFloatNetwork) obj;
    final FlatNetwork flatNet = network.getStructure().getFlat();

    // Header and user-defined properties.
    writer.addSection("BASIC");
    writer.addSubSection("PARAMS");
    writer.addProperties(network.getProperties());

    // Flat network fields; the property order is kept exactly as before.
    writer.addSubSection("NETWORK");
    writer.writeProperty(BasicNetwork.TAG_BEGIN_TRAINING, flatNet.getBeginTraining());
    writer.writeProperty(BasicNetwork.TAG_CONNECTION_LIMIT, flatNet.getConnectionLimit());
    writer.writeProperty(BasicNetwork.TAG_CONTEXT_TARGET_OFFSET, flatNet.getContextTargetOffset());
    writer.writeProperty(BasicNetwork.TAG_CONTEXT_TARGET_SIZE, flatNet.getContextTargetSize());
    writer.writeProperty(BasicNetwork.TAG_END_TRAINING, flatNet.getEndTraining());
    writer.writeProperty(BasicNetwork.TAG_HAS_CONTEXT, flatNet.getHasContext());
    writer.writeProperty(PersistConst.INPUT_COUNT, flatNet.getInputCount());
    writer.writeProperty(BasicNetwork.TAG_LAYER_COUNTS, flatNet.getLayerCounts());
    writer.writeProperty(BasicNetwork.TAG_LAYER_FEED_COUNTS, flatNet.getLayerFeedCounts());
    writer.writeProperty(BasicNetwork.TAG_LAYER_CONTEXT_COUNT, flatNet.getLayerContextCount());
    writer.writeProperty(BasicNetwork.TAG_LAYER_INDEX, flatNet.getLayerIndex());
    writer.writeProperty(PersistConst.OUTPUT, flatNet.getLayerOutput());
    writer.writeProperty(PersistConst.OUTPUT_COUNT, flatNet.getOutputCount());
    writer.writeProperty(BasicNetwork.TAG_WEIGHT_INDEX, flatNet.getWeightIndex());
    writer.writeProperty(PersistConst.WEIGHTS, flatNet.getWeights());
    writer.writeProperty(BasicNetwork.TAG_BIAS_ACTIVATION, flatNet.getBiasActivation());

    // One line per activation function: simple class name, then each parameter.
    writer.addSubSection("ACTIVATION");
    for (final ActivationFunction activation : flatNet.getActivationFunctions()) {
        writer.addColumn(activation.getClass().getSimpleName());
        for (int p = 0; p < activation.getParams().length; p++) {
            writer.addColumn(activation.getParams()[p]);
        }
        writer.writeLine();
    }

    // Feature subset: an empty value stands for "no subset".
    writer.addSubSection("SUBSET");
    final Set<Integer> features = network.getFeatureSet();
    final boolean noSubset = features == null || features.size() == 0;
    if (noSubset) {
        writer.writeProperty("SUBSETFEATURES", "");
    } else {
        writer.writeProperty("SUBSETFEATURES", StringUtils.join(features, ","));
    }
    writer.flush();
}
Also used : FlatNetwork(org.encog.neural.flat.FlatNetwork) ActivationFunction(org.encog.engine.network.activation.ActivationFunction)

Example 7 with FlatNetwork

use of org.encog.neural.flat.FlatNetwork in project shifu by ShifuML.

the class PersistBasicFloatNetwork method read.

/**
 * {@inheritDoc}
 *
 * <p>Rebuilds a {@code BasicFloatNetwork} from the sub-sections of the "BASIC"
 * section: "PARAMS" fills the network properties, "NETWORK" fills the flat
 * network fields, "ACTIVATION" instantiates one activation function per line and
 * "SUBSET" restores the feature subset (set to {@code null} when the stored
 * value is blank). Sections other than "BASIC" are ignored.
 *
 * @param is stream to read the network from
 * @return the populated {@code BasicFloatNetwork}
 * @throws PersistError if an activation function class cannot be found or
 *         instantiated, or if the "ACTIVATION" sub-section appears before the
 *         layer counts are known
 */
@Override
public final Object read(final InputStream is) {
    final BasicFloatNetwork result = new BasicFloatNetwork();
    final FlatNetwork flat = new FlatNetwork();
    final EncogReadHelper in = new EncogReadHelper(is);
    EncogFileSection section;
    while ((section = in.readNextSection()) != null) {
        if (!"BASIC".equals(section.getSectionName())) {
            // Only the BASIC section carries data handled here.
            continue;
        }
        final String subSection = section.getSubSectionName();
        if ("PARAMS".equals(subSection)) {
            final Map<String, String> params = section.parseParams();
            result.getProperties().putAll(params);
        } else if ("NETWORK".equals(subSection)) {
            final Map<String, String> params = section.parseParams();
            flat.setBeginTraining(EncogFileSection.parseInt(params, BasicNetwork.TAG_BEGIN_TRAINING));
            flat.setConnectionLimit(EncogFileSection.parseDouble(params, BasicNetwork.TAG_CONNECTION_LIMIT));
            flat.setContextTargetOffset(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_CONTEXT_TARGET_OFFSET));
            flat.setContextTargetSize(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_CONTEXT_TARGET_SIZE));
            flat.setEndTraining(EncogFileSection.parseInt(params, BasicNetwork.TAG_END_TRAINING));
            flat.setHasContext(EncogFileSection.parseBoolean(params, BasicNetwork.TAG_HAS_CONTEXT));
            flat.setInputCount(EncogFileSection.parseInt(params, PersistConst.INPUT_COUNT));
            flat.setLayerCounts(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_COUNTS));
            flat.setLayerFeedCounts(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_FEED_COUNTS));
            flat.setLayerContextCount(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_CONTEXT_COUNT));
            flat.setLayerIndex(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_INDEX));
            flat.setLayerOutput(EncogFileSection.parseDoubleArray(params, PersistConst.OUTPUT));
            // Layer sums are not persisted; allocate them to match the output buffer.
            flat.setLayerSums(new double[flat.getLayerOutput().length]);
            flat.setOutputCount(EncogFileSection.parseInt(params, PersistConst.OUTPUT_COUNT));
            flat.setWeightIndex(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_WEIGHT_INDEX));
            flat.setWeights(EncogFileSection.parseDoubleArray(params, PersistConst.WEIGHTS));
            flat.setBiasActivation(EncogFileSection.parseDoubleArray(params, BasicNetwork.TAG_BIAS_ACTIVATION));
        } else if ("ACTIVATION".equals(subSection)) {
            // The activation array is sized from the layer counts, which are only
            // set by the NETWORK sub-section. Fail with a descriptive error instead
            // of a bare NullPointerException when that sub-section was not read yet.
            final int[] layerCounts = flat.getLayerCounts();
            if (layerCounts == null) {
                throw new PersistError("ACTIVATION sub-section found before NETWORK sub-section: layer counts are not available");
            }
            int index = 0;
            flat.setActivationFunctions(new ActivationFunction[layerCounts.length]);
            for (final String line : section.getLines()) {
                ActivationFunction af = null;
                final List<String> cols = EncogFileSection.splitColumns(line);
                final String simpleName = cols.get(0);
                // Shifu's own activation functions live in a different package than Encog's.
                String name = "org.encog.engine.network.activation." + simpleName;
                if (simpleName.equals("ActivationReLU")) {
                    name = "ml.shifu.shifu.core.dtrain.nn.ActivationReLU";
                } else if (simpleName.equals("ActivationLeakyReLU")) {
                    name = "ml.shifu.shifu.core.dtrain.nn.ActivationLeakyReLU";
                } else if (simpleName.equals("ActivationSwish")) {
                    name = "ml.shifu.shifu.core.dtrain.nn.ActivationSwish";
                } else if (simpleName.equals("ActivationPTANH")) {
                    name = "ml.shifu.shifu.core.dtrain.nn.ActivationPTANH";
                }
                try {
                    final Class<?> clazz = Class.forName(name);
                    // Class.newInstance() is deprecated and rethrows constructor
                    // exceptions unwrapped; go through the no-arg constructor instead.
                    af = (ActivationFunction) clazz.getDeclaredConstructor().newInstance();
                } catch (final ClassNotFoundException e) {
                    throw new PersistError(e);
                } catch (final NoSuchMethodException e) {
                    throw new PersistError(e);
                } catch (final InstantiationException e) {
                    throw new PersistError(e);
                } catch (final IllegalAccessException e) {
                    throw new PersistError(e);
                } catch (final java.lang.reflect.InvocationTargetException e) {
                    throw new PersistError(e);
                }
                // Column 0 is the class name; parameters start at column 1.
                for (int i = 0; i < af.getParamNames().length; i++) {
                    af.setParam(i, CSVFormat.EG_FORMAT.parse(cols.get(i + 1)));
                }
                flat.getActivationFunctions()[index++] = af;
            }
        } else if ("SUBSET".equals(subSection)) {
            final Map<String, String> params = section.parseParams();
            String subsetStr = params.get("SUBSETFEATURES");
            if (StringUtils.isBlank(subsetStr)) {
                result.setFeatureSet(null);
            } else {
                String[] splits = subsetStr.split(",");
                Set<Integer> subFeatures = new HashSet<Integer>();
                for (String split : splits) {
                    // Tolerate whitespace around the separators (e.g. "1, 2").
                    int featureIndex = Integer.parseInt(split.trim());
                    subFeatures.add(featureIndex);
                }
                result.setFeatureSet(subFeatures);
            }
        }
    }
    result.getStructure().setFlat(flat);
    return result;
}
Also used : FlatNetwork(org.encog.neural.flat.FlatNetwork) ActivationFunction(org.encog.engine.network.activation.ActivationFunction)

Example 8 with FlatNetwork

use of org.encog.neural.flat.FlatNetwork in project shifu by ShifuML.

the class PersistBasicFloatNetwork method readNetwork.

/**
 * Reads a {@code BasicFloatNetwork} from a binary stream.
 *
 * <p>The stream is consumed in this order: the property count followed by that
 * many key/value string pairs, the flat network fields, the activation function
 * count followed by each function's simple class name and parameter array, and
 * finally the feature subset size followed by that many feature indexes.
 *
 * @param in source to read from
 * @return the populated network; its feature set is an empty set when the
 *         stored subset size is zero
 * @throws IOException if reading fails, or if the stream declares more
 *         activation functions than there are layers
 * @throws PersistError if an activation function class cannot be found or
 *         instantiated
 */
public BasicFloatNetwork readNetwork(final DataInput in) throws IOException {
    final BasicFloatNetwork result = new BasicFloatNetwork();
    final FlatNetwork flat = new FlatNetwork();
    // read properties
    Map<String, String> properties = new HashMap<String, String>();
    int size = in.readInt();
    for (int i = 0; i < size; i++) {
        properties.put(ml.shifu.shifu.core.dtrain.StringUtils.readString(in), ml.shifu.shifu.core.dtrain.StringUtils.readString(in));
    }
    result.getProperties().putAll(properties);
    // read fields
    flat.setBeginTraining(in.readInt());
    flat.setConnectionLimit(in.readDouble());
    flat.setContextTargetOffset(readIntArray(in));
    flat.setContextTargetSize(readIntArray(in));
    flat.setEndTraining(in.readInt());
    flat.setHasContext(in.readBoolean());
    flat.setInputCount(in.readInt());
    flat.setLayerCounts(readIntArray(in));
    flat.setLayerFeedCounts(readIntArray(in));
    flat.setLayerContextCount(readIntArray(in));
    flat.setLayerIndex(readIntArray(in));
    flat.setLayerOutput(readDoubleArray(in));
    flat.setOutputCount(in.readInt());
    // Layer sums are not stored; allocate them to match the output buffer.
    flat.setLayerSums(new double[flat.getLayerOutput().length]);
    flat.setWeightIndex(readIntArray(in));
    flat.setWeights(readDoubleArray(in));
    flat.setBiasActivation(readDoubleArray(in));
    // read activations
    final int layerCount = flat.getLayerCounts().length;
    flat.setActivationFunctions(new ActivationFunction[layerCount]);
    int acSize = in.readInt();
    if (acSize > layerCount) {
        // Fail fast on a corrupt stream rather than overrunning the array later.
        throw new IOException("Corrupt network stream: " + acSize + " activation functions declared for " + layerCount + " layers");
    }
    for (int i = 0; i < acSize; i++) {
        String name = ml.shifu.shifu.core.dtrain.StringUtils.readString(in);
        // Shifu's own activation functions live outside the Encog activation package.
        if (name.equals("ActivationReLU")) {
            name = ActivationReLU.class.getName();
        } else if (name.equals("ActivationLeakyReLU")) {
            name = ActivationLeakyReLU.class.getName();
        } else if (name.equals("ActivationSwish")) {
            name = ActivationSwish.class.getName();
        } else if (name.equals("ActivationPTANH")) {
            name = ActivationPTANH.class.getName();
        } else {
            name = "org.encog.engine.network.activation." + name;
        }
        ActivationFunction af = null;
        try {
            final Class<?> clazz = Class.forName(name);
            // Class.newInstance() is deprecated and rethrows constructor
            // exceptions unwrapped; go through the no-arg constructor instead.
            af = (ActivationFunction) clazz.getDeclaredConstructor().newInstance();
        } catch (final ClassNotFoundException e) {
            throw new PersistError(e);
        } catch (final NoSuchMethodException e) {
            throw new PersistError(e);
        } catch (final InstantiationException e) {
            throw new PersistError(e);
        } catch (final IllegalAccessException e) {
            throw new PersistError(e);
        } catch (final java.lang.reflect.InvocationTargetException e) {
            throw new PersistError(e);
        }
        double[] params = readDoubleArray(in);
        for (int j = 0; j < params.length; j++) {
            af.setParam(j, params[j]);
        }
        flat.getActivationFunctions()[i] = af;
    }
    // read subset
    int subsetSize = in.readInt();
    Set<Integer> featureList = new HashSet<Integer>();
    for (int i = 0; i < subsetSize; i++) {
        featureList.add(in.readInt());
    }
    result.setFeatureSet(featureList);
    result.getStructure().setFlat(flat);
    return result;
}
Also used : ActivationSwish(ml.shifu.shifu.core.dtrain.nn.ActivationSwish) FlatNetwork(org.encog.neural.flat.FlatNetwork) ActivationFunction(org.encog.engine.network.activation.ActivationFunction) ActivationReLU(ml.shifu.shifu.core.dtrain.nn.ActivationReLU)

Example 9 with FlatNetwork

use of org.encog.neural.flat.FlatNetwork in project shifu by ShifuML.

the class NNModelSpecTest method testModelStructureCompare.

/**
 * Loads four persisted models and checks the values returned by
 * {@code NNStructureComparator.compare} for each pairing of their flat networks.
 */
@Test
public void testModelStructureCompare() {
    // Baseline model and its extended variant.
    final File baseFile = new File("src/test/resources/model/model0.nn");
    final BasicML baseModel = BasicML.class.cast(EncogDirectoryPersistence.loadObject(baseFile));
    final FlatNetwork base = ((BasicNetwork) baseModel).getFlat();

    final File extendedFile = new File("src/test/resources/model/model1.nn");
    final BasicML extendedModel = BasicML.class.cast(EncogDirectoryPersistence.loadObject(extendedFile));
    final FlatNetwork extended = ((BasicNetwork) extendedModel).getFlat();

    Assert.assertEquals(new NNStructureComparator().compare(base, extended), -1);
    Assert.assertEquals(new NNStructureComparator().compare(base, base), 0);
    Assert.assertEquals(new NNStructureComparator().compare(extended, base), 1);

    // Model with a different structure.
    final File differentFile = new File("src/test/resources/model/model2.nn");
    final BasicML differentModel = BasicML.class.cast(EncogDirectoryPersistence.loadObject(differentFile));
    final FlatNetwork different = ((BasicNetwork) differentModel).getFlat();

    Assert.assertEquals(new NNStructureComparator().compare(base, different), -1);
    Assert.assertEquals(new NNStructureComparator().compare(different, base), -1);
    Assert.assertEquals(new NNStructureComparator().compare(extended, different), 1);
    Assert.assertEquals(new NNStructureComparator().compare(different, extended), -1);

    // Deeper model.
    final File deepFile = new File("src/test/resources/model/model3.nn");
    final BasicML deepModel = BasicML.class.cast(EncogDirectoryPersistence.loadObject(deepFile));
    final FlatNetwork deep = ((BasicNetwork) deepModel).getFlat();

    Assert.assertEquals(new NNStructureComparator().compare(deep, base), 1);
    Assert.assertEquals(new NNStructureComparator().compare(base, deep), -1);
}
Also used : FlatNetwork(org.encog.neural.flat.FlatNetwork) BasicNetwork(org.encog.neural.networks.BasicNetwork) NNStructureComparator(ml.shifu.shifu.core.dtrain.nn.NNStructureComparator) BasicML(org.encog.ml.BasicML) File(java.io.File) Test(org.testng.annotations.Test)

Example 10 with FlatNetwork

use of org.encog.neural.flat.FlatNetwork in project shifu by ShifuML.

the class NNModelSpecTest method testFitExistingModelIn.

/**
 * Checks how many weight indexes {@code NNMaster.fitExistingModelIn} reports as
 * fixed, first when fitting a model into itself and then into an extended model.
 */
@Test
public void testFitExistingModelIn() {
    final File baseFile = new File("src/test/resources/model/model0.nn");
    final BasicML baseModel = BasicML.class.cast(EncogDirectoryPersistence.loadObject(baseFile));
    final FlatNetwork base = ((BasicNetwork) baseModel).getFlat();
    final NNMaster nnMaster = new NNMaster();

    // Same network as source and target, layer list {6}.
    Set<Integer> fixedWeights = nnMaster.fitExistingModelIn(base, base, Arrays.asList(new Integer[] { 6 }));
    List<Integer> sortedIndexes = new ArrayList<Integer>(fixedWeights);
    Collections.sort(sortedIndexes);
    Assert.assertEquals(sortedIndexes.size(), 31);

    // Same network as source and target, layer list {1}.
    fixedWeights = nnMaster.fitExistingModelIn(base, base, Arrays.asList(new Integer[] { 1 }));
    sortedIndexes = new ArrayList<Integer>(fixedWeights);
    Collections.sort(sortedIndexes);
    Assert.assertEquals(sortedIndexes.size(), 930);

    final File extendedFile = new File("src/test/resources/model/model1.nn");
    final BasicML extendedModel = BasicML.class.cast(EncogDirectoryPersistence.loadObject(extendedFile));
    final FlatNetwork extended = ((BasicNetwork) extendedModel).getFlat();

    // Fit into the extended network using the three-argument overload.
    fixedWeights = nnMaster.fitExistingModelIn(base, extended, Arrays.asList(new Integer[] { 1 }));
    sortedIndexes = new ArrayList<Integer>(fixedWeights);
    Collections.sort(sortedIndexes);
    Assert.assertEquals(sortedIndexes.size(), 930);

    // Same fit with the trailing boolean argument set to false.
    fixedWeights = nnMaster.fitExistingModelIn(base, extended, Arrays.asList(new Integer[] { 1 }), false);
    sortedIndexes = new ArrayList<Integer>(fixedWeights);
    Collections.sort(sortedIndexes);
    Assert.assertEquals(sortedIndexes.size(), 900);
}
Also used : FlatNetwork(org.encog.neural.flat.FlatNetwork) NNMaster(ml.shifu.shifu.core.dtrain.nn.NNMaster) BasicNetwork(org.encog.neural.networks.BasicNetwork) ArrayList(java.util.ArrayList) BasicML(org.encog.ml.BasicML) File(java.io.File) Test(org.testng.annotations.Test)

Aggregations

FlatNetwork (org.encog.neural.flat.FlatNetwork): 10, ActivationFunction (org.encog.engine.network.activation.ActivationFunction): 5, BasicNetwork (org.encog.neural.networks.BasicNetwork): 5, File (java.io.File): 4, BasicML (org.encog.ml.BasicML): 4, Test (org.testng.annotations.Test): 3, ArrayList (java.util.ArrayList): 2, NNMaster (ml.shifu.shifu.core.dtrain.nn.NNMaster): 2, ActivationSigmoid (org.encog.engine.network.activation.ActivationSigmoid): 2, LinearErrorFunction (org.encog.neural.error.LinearErrorFunction): 2, List (java.util.List): 1, FloatFlatNetwork (ml.shifu.shifu.core.dtrain.dataset.FloatFlatNetwork): 1, PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork): 1, ActivationReLU (ml.shifu.shifu.core.dtrain.nn.ActivationReLU): 1, ActivationSwish (ml.shifu.shifu.core.dtrain.nn.ActivationSwish): 1, NNStructureComparator (ml.shifu.shifu.core.dtrain.nn.NNStructureComparator): 1