Search in sources:

Example 1 with ActivationFunction

Use of org.encog.engine.network.activation.ActivationFunction in project shifu by ShifuML.

From the class Gradient, method processLevel:

/**
 * Process one level: back-propagate error deltas across the weight matrix
 * between layer {@code currentLevel} (the "to" layer, closer to the output)
 * and layer {@code currentLevel + 1} (the "from" layer), while accumulating
 * weight gradients for this training pass.
 *
 * NOTE(review): layer indexing here follows Encog's FlatNetwork convention,
 * where index 0 is the output layer and higher indices move toward the input.
 *
 * @param currentLevel
 *            The level. Index of the "to" layer in the flat layer arrays.
 */
private void processLevel(final int currentLevel) {
    // Offsets of the two layers inside the flat layerOutput/layerSums arrays.
    final int fromLayerIndex = this.layerIndex[currentLevel + 1];
    final int toLayerIndex = this.layerIndex[currentLevel];
    final int fromLayerSize = this.layerCounts[currentLevel + 1];
    // Only feed-forward neurons of the "to" layer receive deltas (bias/context excluded).
    final int toLayerSize = this.layerFeedCounts[currentLevel];
    // Start of this level's weights in the flat weight array.
    final int index = this.weightIndex[currentLevel];
    final ActivationFunction activation = this.getNetwork().getActivationFunctions()[currentLevel + 1];
    // Flat-spot offset added to the derivative to keep learning alive where
    // the derivative is near zero (e.g. saturated sigmoid).
    final double currentFlatSpot = this.flatSpot[currentLevel + 1];
    // handle weights
    int yi = fromLayerIndex;
    for (int y = 0; y < fromLayerSize; y++) {
        final double output = this.layerOutput[yi];
        double sum = 0;
        int xi = toLayerIndex;
        // Weights are laid out so that consecutive "to" neurons for the same
        // "from" neuron are fromLayerSize apart; hence the wi stride below.
        int wi = index + y;
        for (int x = 0; x < toLayerSize; x++) {
            // Gradient contribution: from-neuron output times to-neuron delta.
            this.gradients[wi] += output * this.getLayerDelta()[xi];
            // Accumulate the weighted delta to propagate back to the from-neuron.
            sum += this.weights[wi] * this.getLayerDelta()[xi];
            wi += fromLayerSize;
            xi++;
        }
        // Delta for the from-neuron: weighted delta sum scaled by the
        // activation derivative (plus the flat-spot correction).
        this.getLayerDelta()[yi] = sum * (activation.derivativeFunction(this.layerSums[yi], this.layerOutput[yi]) + currentFlatSpot);
        yi++;
    }
}
Also used : ActivationFunction(org.encog.engine.network.activation.ActivationFunction)

Example 2 with ActivationFunction

Use of org.encog.engine.network.activation.ActivationFunction in project shifu by ShifuML.

From the class PersistBasicFloatNetwork, method saveNetwork:

/**
 * Serializes the given {@link BasicFloatNetwork} to a binary stream.
 *
 * <p>The write order here defines the wire format and must stay in sync with
 * the corresponding binary read path: properties map, flat-network scalar and
 * array fields, the activation-function list (class simple name + params),
 * and finally the feature subset. Empty/absent collections are encoded as a
 * zero count.
 *
 * @param out
 *            the binary output to write to; never closed by this method
 * @param network
 *            the network to serialize
 * @throws IOException
 *             if writing to {@code out} fails
 */
public void saveNetwork(DataOutput out, final BasicFloatNetwork network) throws IOException {
    final FlatNetwork flat = network.getStructure().getFlat();
    // write general properties
    Map<String, String> properties = network.getProperties();
    if (properties == null) {
        // No properties: encode as a zero-sized map.
        out.writeInt(0);
    } else {
        out.writeInt(properties.size());
        for (Entry<String, String> entry : properties.entrySet()) {
            ml.shifu.shifu.core.dtrain.StringUtils.writeString(out, entry.getKey());
            ml.shifu.shifu.core.dtrain.StringUtils.writeString(out, entry.getValue());
        }
    }
    // write fields values in BasicFloatNetwork
    out.writeInt(flat.getBeginTraining());
    out.writeDouble(flat.getConnectionLimit());
    writeIntArray(out, flat.getContextTargetOffset());
    writeIntArray(out, flat.getContextTargetSize());
    out.writeInt(flat.getEndTraining());
    out.writeBoolean(flat.getHasContext());
    out.writeInt(flat.getInputCount());
    writeIntArray(out, flat.getLayerCounts());
    writeIntArray(out, flat.getLayerFeedCounts());
    writeIntArray(out, flat.getLayerContextCount());
    writeIntArray(out, flat.getLayerIndex());
    writeDoubleArray(out, flat.getLayerOutput());
    out.writeInt(flat.getOutputCount());
    writeIntArray(out, flat.getWeightIndex());
    writeDoubleArray(out, flat.getWeights());
    writeDoubleArray(out, flat.getBiasActivation());
    // write activation list: simple class name is enough because the reader
    // resolves it against known activation packages.
    out.writeInt(flat.getActivationFunctions().length);
    for (final ActivationFunction af : flat.getActivationFunctions()) {
        ml.shifu.shifu.core.dtrain.StringUtils.writeString(out, af.getClass().getSimpleName());
        writeDoubleArray(out, af.getParams());
    }
    // write sub sets (selected feature indexes); zero count means "all features"
    Set<Integer> featureList = network.getFeatureSet();
    if (featureList == null || featureList.isEmpty()) {
        out.writeInt(0);
    } else {
        out.writeInt(featureList.size());
        for (Integer integer : featureList) {
            out.writeInt(integer);
        }
    }
}
Also used : FlatNetwork(org.encog.neural.flat.FlatNetwork) ActivationFunction(org.encog.engine.network.activation.ActivationFunction)

Example 3 with ActivationFunction

Use of org.encog.engine.network.activation.ActivationFunction in project shifu by ShifuML.

From the class DTrainTest, method initGradient:

/**
 * Builds a {@link Gradient} calculator over a clone of this test's network,
 * mirroring how Encog's Propagation trainer seeds its flat-spot table.
 *
 * @param training
 *            the training data set to compute gradients against
 * @return a gradient calculator backed by a cloned flat network
 */
public Gradient initGradient(MLDataSet training) {
    FlatNetwork flat = network.getFlat().clone();
    // copy Propagation from encog: sigmoid layers get a 0.1 flat-spot
    // offset so learning does not stall where the derivative is ~0.
    ActivationFunction[] activations = flat.getActivationFunctions();
    double[] flatSpot = new double[activations.length];
    for (int i = 0; i < activations.length; i++) {
        flatSpot[i] = (activations[i] instanceof ActivationSigmoid) ? 0.1 : 0.0;
    }
    return new Gradient(flat, training.openAdditional(), training, flatSpot, new LinearErrorFunction(), false);
}
Also used : LinearErrorFunction(org.encog.neural.error.LinearErrorFunction) FlatNetwork(org.encog.neural.flat.FlatNetwork) ActivationFunction(org.encog.engine.network.activation.ActivationFunction) ActivationSigmoid(org.encog.engine.network.activation.ActivationSigmoid)

Example 4 with ActivationFunction

Use of org.encog.engine.network.activation.ActivationFunction in project shifu by ShifuML.

From the class PersistBasicFloatNetwork, method save:

/**
 * {@inheritDoc}
 *
 * <p>Writes the network in Encog's EG text format: a {@code BASIC} section
 * with {@code PARAMS}, {@code NETWORK}, {@code ACTIVATION} and the
 * Shifu-specific {@code SUBSET} sub-sections. The property order must stay
 * in sync with {@code read(InputStream)}.
 */
@Override
public final void save(final OutputStream os, final Object obj) {
    final EncogWriteHelper out = new EncogWriteHelper(os);
    final BasicFloatNetwork net = (BasicFloatNetwork) obj;
    final FlatNetwork flat = net.getStructure().getFlat();
    out.addSection("BASIC");
    out.addSubSection("PARAMS");
    out.addProperties(net.getProperties());
    out.addSubSection("NETWORK");
    out.writeProperty(BasicNetwork.TAG_BEGIN_TRAINING, flat.getBeginTraining());
    out.writeProperty(BasicNetwork.TAG_CONNECTION_LIMIT, flat.getConnectionLimit());
    out.writeProperty(BasicNetwork.TAG_CONTEXT_TARGET_OFFSET, flat.getContextTargetOffset());
    out.writeProperty(BasicNetwork.TAG_CONTEXT_TARGET_SIZE, flat.getContextTargetSize());
    out.writeProperty(BasicNetwork.TAG_END_TRAINING, flat.getEndTraining());
    out.writeProperty(BasicNetwork.TAG_HAS_CONTEXT, flat.getHasContext());
    out.writeProperty(PersistConst.INPUT_COUNT, flat.getInputCount());
    out.writeProperty(BasicNetwork.TAG_LAYER_COUNTS, flat.getLayerCounts());
    out.writeProperty(BasicNetwork.TAG_LAYER_FEED_COUNTS, flat.getLayerFeedCounts());
    out.writeProperty(BasicNetwork.TAG_LAYER_CONTEXT_COUNT, flat.getLayerContextCount());
    out.writeProperty(BasicNetwork.TAG_LAYER_INDEX, flat.getLayerIndex());
    out.writeProperty(PersistConst.OUTPUT, flat.getLayerOutput());
    out.writeProperty(PersistConst.OUTPUT_COUNT, flat.getOutputCount());
    out.writeProperty(BasicNetwork.TAG_WEIGHT_INDEX, flat.getWeightIndex());
    out.writeProperty(PersistConst.WEIGHTS, flat.getWeights());
    out.writeProperty(BasicNetwork.TAG_BIAS_ACTIVATION, flat.getBiasActivation());
    out.addSubSection("ACTIVATION");
    // One line per activation: simple class name followed by its parameters.
    for (final ActivationFunction af : flat.getActivationFunctions()) {
        out.addColumn(af.getClass().getSimpleName());
        // Hoist getParams() out of the loop condition: no need to re-fetch
        // the array on every iteration.
        final double[] params = af.getParams();
        for (int i = 0; i < params.length; i++) {
            out.addColumn(params[i]);
        }
        out.writeLine();
    }
    out.addSubSection("SUBSET");
    Set<Integer> featureList = net.getFeatureSet();
    if (featureList == null || featureList.isEmpty()) {
        // Empty subset string means "all features" on the read side.
        out.writeProperty("SUBSETFEATURES", "");
    } else {
        String subFeaturesStr = StringUtils.join(featureList, ",");
        out.writeProperty("SUBSETFEATURES", subFeaturesStr);
    }
    out.flush();
}
Also used : FlatNetwork(org.encog.neural.flat.FlatNetwork) ActivationFunction(org.encog.engine.network.activation.ActivationFunction)

Example 5 with ActivationFunction

Use of org.encog.engine.network.activation.ActivationFunction in project shifu by ShifuML.

From the class PersistBasicFloatNetwork, method read:

/**
 * {@inheritDoc}
 *
 * <p>Reads a {@link BasicFloatNetwork} from Encog's EG text format, the
 * inverse of {@code save(OutputStream, Object)}: parses the {@code BASIC}
 * section's {@code PARAMS}, {@code NETWORK}, {@code ACTIVATION} and
 * Shifu-specific {@code SUBSET} sub-sections. Activation classes are
 * instantiated reflectively by simple name; Shifu's own activations
 * (ReLU, LeakyReLU, Swish, PTANH) are mapped to their Shifu package,
 * everything else to Encog's activation package.
 */
@Override
public final Object read(final InputStream is) {
    final BasicFloatNetwork result = new BasicFloatNetwork();
    final FlatNetwork flat = new FlatNetwork();
    final EncogReadHelper in = new EncogReadHelper(is);
    EncogFileSection section;
    while ((section = in.readNextSection()) != null) {
        if (section.getSectionName().equals("BASIC") && section.getSubSectionName().equals("PARAMS")) {
            final Map<String, String> params = section.parseParams();
            result.getProperties().putAll(params);
        }
        if (section.getSectionName().equals("BASIC") && section.getSubSectionName().equals("NETWORK")) {
            final Map<String, String> params = section.parseParams();
            flat.setBeginTraining(EncogFileSection.parseInt(params, BasicNetwork.TAG_BEGIN_TRAINING));
            flat.setConnectionLimit(EncogFileSection.parseDouble(params, BasicNetwork.TAG_CONNECTION_LIMIT));
            flat.setContextTargetOffset(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_CONTEXT_TARGET_OFFSET));
            flat.setContextTargetSize(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_CONTEXT_TARGET_SIZE));
            flat.setEndTraining(EncogFileSection.parseInt(params, BasicNetwork.TAG_END_TRAINING));
            flat.setHasContext(EncogFileSection.parseBoolean(params, BasicNetwork.TAG_HAS_CONTEXT));
            flat.setInputCount(EncogFileSection.parseInt(params, PersistConst.INPUT_COUNT));
            flat.setLayerCounts(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_COUNTS));
            flat.setLayerFeedCounts(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_FEED_COUNTS));
            flat.setLayerContextCount(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_CONTEXT_COUNT));
            flat.setLayerIndex(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_INDEX));
            flat.setLayerOutput(EncogFileSection.parseDoubleArray(params, PersistConst.OUTPUT));
            // Layer sums are transient runtime state, not persisted: allocate fresh.
            flat.setLayerSums(new double[flat.getLayerOutput().length]);
            flat.setOutputCount(EncogFileSection.parseInt(params, PersistConst.OUTPUT_COUNT));
            flat.setWeightIndex(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_WEIGHT_INDEX));
            flat.setWeights(EncogFileSection.parseDoubleArray(params, PersistConst.WEIGHTS));
            flat.setBiasActivation(EncogFileSection.parseDoubleArray(params, BasicNetwork.TAG_BIAS_ACTIVATION));
        } else if (section.getSectionName().equals("BASIC") && section.getSubSectionName().equals("ACTIVATION")) {
            int index = 0;
            // One activation per layer; NETWORK sub-section must have been read first.
            flat.setActivationFunctions(new ActivationFunction[flat.getLayerCounts().length]);
            for (final String line : section.getLines()) {
                ActivationFunction af = null;
                final List<String> cols = EncogFileSection.splitColumns(line);
                // Hoist the repeated cols.get(0) lookups into one local.
                final String simpleName = cols.get(0);
                // Shifu-specific activations live outside the Encog package.
                String name = "org.encog.engine.network.activation." + simpleName;
                if (simpleName.equals("ActivationReLU")) {
                    name = "ml.shifu.shifu.core.dtrain.nn.ActivationReLU";
                } else if (simpleName.equals("ActivationLeakyReLU")) {
                    name = "ml.shifu.shifu.core.dtrain.nn.ActivationLeakyReLU";
                } else if (simpleName.equals("ActivationSwish")) {
                    name = "ml.shifu.shifu.core.dtrain.nn.ActivationSwish";
                } else if (simpleName.equals("ActivationPTANH")) {
                    name = "ml.shifu.shifu.core.dtrain.nn.ActivationPTANH";
                }
                try {
                    final Class<?> clazz = Class.forName(name);
                    af = (ActivationFunction) clazz.newInstance();
                } catch (final ReflectiveOperationException e) {
                    // Collapses the former ClassNotFound/Instantiation/IllegalAccess
                    // triple-catch: all three extend ReflectiveOperationException.
                    throw new PersistError(e);
                }
                for (int i = 0; i < af.getParamNames().length; i++) {
                    // Column 0 is the class name, so params start at column 1.
                    af.setParam(i, CSVFormat.EG_FORMAT.parse(cols.get(i + 1)));
                }
                flat.getActivationFunctions()[index++] = af;
            }
        } else if (section.getSectionName().equals("BASIC") && section.getSubSectionName().equals("SUBSET")) {
            final Map<String, String> params = section.parseParams();
            String subsetStr = params.get("SUBSETFEATURES");
            if (StringUtils.isBlank(subsetStr)) {
                // Blank means no explicit subset: all features are in play.
                result.setFeatureSet(null);
            } else {
                String[] splits = subsetStr.split(",");
                Set<Integer> subFeatures = new HashSet<Integer>();
                for (String split : splits) {
                    int featureIndex = Integer.parseInt(split);
                    subFeatures.add(featureIndex);
                }
                result.setFeatureSet(subFeatures);
            }
        }
    }
    result.getStructure().setFlat(flat);
    return result;
}
Also used : FlatNetwork(org.encog.neural.flat.FlatNetwork) ActivationFunction(org.encog.engine.network.activation.ActivationFunction)

Aggregations

ActivationFunction (org.encog.engine.network.activation.ActivationFunction)7 FlatNetwork (org.encog.neural.flat.FlatNetwork)5 ActivationReLU (ml.shifu.shifu.core.dtrain.nn.ActivationReLU)1 ActivationSwish (ml.shifu.shifu.core.dtrain.nn.ActivationSwish)1 ActivationSigmoid (org.encog.engine.network.activation.ActivationSigmoid)1 LinearErrorFunction (org.encog.neural.error.LinearErrorFunction)1