Example usage of org.encog.engine.network.activation.ActivationFunction in the ShifuML/shifu project:
class Gradient, method processLevel.
/**
 * Back-propagates the error through one level of the flat network.
 *
 * <p>For the weight matrix connecting layer {@code currentLevel} (the "to"
 * layer, whose deltas are already known) and layer {@code currentLevel + 1}
 * (the "from" layer), this accumulates the per-weight gradients and then
 * computes the deltas of the "from" layer.</p>
 *
 * <p>NOTE(review): the to/from orientation assumes Encog's flat-network
 * convention where layer 0 is the output layer — confirm against
 * {@code FlatNetwork}'s layer layout.</p>
 *
 * @param currentLevel
 *            The level (weight-matrix index) to process.
 */
private void processLevel(final int currentLevel) {
    // Absolute offsets of the two layers within the flat output/delta arrays.
    final int fromLayerIndex = this.layerIndex[currentLevel + 1];
    final int toLayerIndex = this.layerIndex[currentLevel];
    final int fromLayerSize = this.layerCounts[currentLevel + 1];
    // Only the feed-forward neurons of the "to" layer carry error (bias/context
    // neurons are excluded by using layerFeedCounts rather than layerCounts).
    final int toLayerSize = this.layerFeedCounts[currentLevel];
    // Start of this level's weight matrix inside the flat weight array.
    final int index = this.weightIndex[currentLevel];
    final ActivationFunction activation = this.getNetwork().getActivationFunctions()[currentLevel + 1];
    // Flat-spot correction added to the derivative below; keeps learning moving
    // when the activation derivative is near zero.
    final double currentFlatSpot = this.flatSpot[currentLevel + 1];
    // handle weights
    int yi = fromLayerIndex;
    for (int y = 0; y < fromLayerSize; y++) {
        final double output = this.layerOutput[yi];
        double sum = 0;
        int xi = toLayerIndex;
        // Weights feeding from neuron y are strided by fromLayerSize in the
        // flat weight array, so wi advances by fromLayerSize per "to" neuron.
        int wi = index + y;
        for (int x = 0; x < toLayerSize; x++) {
            // Gradient of weight (y -> x): from-neuron output * to-neuron delta.
            this.gradients[wi] += output * this.getLayerDelta()[xi];
            // Accumulate the back-propagated error for "from" neuron y.
            sum += this.weights[wi] * this.getLayerDelta()[xi];
            wi += fromLayerSize;
            xi++;
        }
        // delta(y) = back-propagated error * (activation derivative + flat spot).
        this.getLayerDelta()[yi] = sum * (activation.derivativeFunction(this.layerSums[yi], this.layerOutput[yi]) + currentFlatSpot);
        yi++;
    }
}
Example usage of org.encog.engine.network.activation.ActivationFunction in the ShifuML/shifu project:
class PersistBasicFloatNetwork, method saveNetwork.
/**
 * Serializes a {@link BasicFloatNetwork} to a binary stream.
 *
 * <p>The sequence of writes below <em>is</em> the wire format: general
 * properties, the flat-network fields, the activation-function list, and
 * finally the feature subset. Any change here must be mirrored in the
 * corresponding read/deserialization method, or old models become
 * unreadable.</p>
 *
 * @param out
 *            the binary sink to write to
 * @param network
 *            the network to serialize
 * @throws IOException
 *             if writing to {@code out} fails
 */
public void saveNetwork(DataOutput out, final BasicFloatNetwork network) throws IOException {
    final FlatNetwork flat = network.getStructure().getFlat();
    // write general properties: element count first, then (key, value) string pairs
    Map<String, String> properties = network.getProperties();
    if (properties == null) {
        // no properties: a zero count lets the reader skip the section
        out.writeInt(0);
    } else {
        out.writeInt(properties.size());
        for (Entry<String, String> entry : properties.entrySet()) {
            ml.shifu.shifu.core.dtrain.StringUtils.writeString(out, entry.getKey());
            ml.shifu.shifu.core.dtrain.StringUtils.writeString(out, entry.getValue());
        }
    }
    // write fields values in BasicFloatNetwork (order must match the reader)
    out.writeInt(flat.getBeginTraining());
    out.writeDouble(flat.getConnectionLimit());
    writeIntArray(out, flat.getContextTargetOffset());
    writeIntArray(out, flat.getContextTargetSize());
    out.writeInt(flat.getEndTraining());
    out.writeBoolean(flat.getHasContext());
    out.writeInt(flat.getInputCount());
    writeIntArray(out, flat.getLayerCounts());
    writeIntArray(out, flat.getLayerFeedCounts());
    writeIntArray(out, flat.getLayerContextCount());
    writeIntArray(out, flat.getLayerIndex());
    writeDoubleArray(out, flat.getLayerOutput());
    out.writeInt(flat.getOutputCount());
    writeIntArray(out, flat.getWeightIndex());
    writeDoubleArray(out, flat.getWeights());
    writeDoubleArray(out, flat.getBiasActivation());
    // write activation list: count, then (simple class name, params) per function;
    // the reader reconstructs each function reflectively from its simple name
    out.writeInt(flat.getActivationFunctions().length);
    for (final ActivationFunction af : flat.getActivationFunctions()) {
        ml.shifu.shifu.core.dtrain.StringUtils.writeString(out, af.getClass().getSimpleName());
        writeDoubleArray(out, af.getParams());
    }
    // write sub sets: count, then one int per selected feature index
    Set<Integer> featureList = network.getFeatureSet();
    if (featureList == null || featureList.size() == 0) {
        // empty and null feature sets are encoded identically as a zero count
        out.writeInt(0);
    } else {
        out.writeInt(featureList.size());
        for (Integer integer : featureList) {
            out.writeInt(integer);
        }
    }
}
Example usage of org.encog.engine.network.activation.ActivationFunction in the ShifuML/shifu project:
class DTrainTest, method initGradient.
/**
 * Builds a {@link Gradient} calculator over a cloned flat network.
 *
 * <p>Mirrors the flat-spot setup performed by encog's {@code Propagation}:
 * sigmoid layers receive a 0.1 flat-spot correction, all others 0.0.</p>
 *
 * @param training
 *            the training data set the gradient will iterate over
 * @return a gradient calculator bound to a clone of this test's network
 */
public Gradient initGradient(MLDataSet training) {
    final FlatNetwork flatNetwork = network.getFlat().clone();
    // copy Propagation from encog: one flat-spot entry per activation function
    final ActivationFunction[] activations = flatNetwork.getActivationFunctions();
    final double[] flatSpots = new double[activations.length];
    for (int idx = 0; idx < activations.length; idx++) {
        flatSpots[idx] = (activations[idx] instanceof ActivationSigmoid) ? 0.1 : 0.0;
    }
    return new Gradient(flatNetwork, training.openAdditional(), training, flatSpots, new LinearErrorFunction(), false);
}
Example usage of org.encog.engine.network.activation.ActivationFunction in the ShifuML/shifu project:
class PersistBasicFloatNetwork, method save.
/**
 * {@inheritDoc}
 *
 * <p>Writes the network in Encog's EG text format: a {@code BASIC} section
 * with {@code PARAMS}, {@code NETWORK}, {@code ACTIVATION} and {@code SUBSET}
 * subsections. The property write order below is the file format and must
 * stay in lock-step with {@code read(InputStream)}.</p>
 */
@Override
public final void save(final OutputStream os, final Object obj) {
    final EncogWriteHelper out = new EncogWriteHelper(os);
    final BasicFloatNetwork net = (BasicFloatNetwork) obj;
    final FlatNetwork flat = net.getStructure().getFlat();
    out.addSection("BASIC");
    out.addSubSection("PARAMS");
    out.addProperties(net.getProperties());
    out.addSubSection("NETWORK");
    out.writeProperty(BasicNetwork.TAG_BEGIN_TRAINING, flat.getBeginTraining());
    out.writeProperty(BasicNetwork.TAG_CONNECTION_LIMIT, flat.getConnectionLimit());
    out.writeProperty(BasicNetwork.TAG_CONTEXT_TARGET_OFFSET, flat.getContextTargetOffset());
    out.writeProperty(BasicNetwork.TAG_CONTEXT_TARGET_SIZE, flat.getContextTargetSize());
    out.writeProperty(BasicNetwork.TAG_END_TRAINING, flat.getEndTraining());
    out.writeProperty(BasicNetwork.TAG_HAS_CONTEXT, flat.getHasContext());
    out.writeProperty(PersistConst.INPUT_COUNT, flat.getInputCount());
    out.writeProperty(BasicNetwork.TAG_LAYER_COUNTS, flat.getLayerCounts());
    out.writeProperty(BasicNetwork.TAG_LAYER_FEED_COUNTS, flat.getLayerFeedCounts());
    out.writeProperty(BasicNetwork.TAG_LAYER_CONTEXT_COUNT, flat.getLayerContextCount());
    out.writeProperty(BasicNetwork.TAG_LAYER_INDEX, flat.getLayerIndex());
    out.writeProperty(PersistConst.OUTPUT, flat.getLayerOutput());
    out.writeProperty(PersistConst.OUTPUT_COUNT, flat.getOutputCount());
    out.writeProperty(BasicNetwork.TAG_WEIGHT_INDEX, flat.getWeightIndex());
    out.writeProperty(PersistConst.WEIGHTS, flat.getWeights());
    out.writeProperty(BasicNetwork.TAG_BIAS_ACTIVATION, flat.getBiasActivation());
    // one line per activation function: simple class name followed by its params;
    // read(InputStream) reconstructs each function reflectively from that name
    out.addSubSection("ACTIVATION");
    for (final ActivationFunction af : flat.getActivationFunctions()) {
        out.addColumn(af.getClass().getSimpleName());
        for (int i = 0; i < af.getParams().length; i++) {
            out.addColumn(af.getParams()[i]);
        }
        out.writeLine();
    }
    // feature subset as a comma-joined index list; empty/null become ""
    out.addSubSection("SUBSET");
    Set<Integer> featureList = net.getFeatureSet();
    if (featureList == null || featureList.size() == 0) {
        out.writeProperty("SUBSETFEATURES", "");
    } else {
        String subFeaturesStr = StringUtils.join(featureList, ",");
        out.writeProperty("SUBSETFEATURES", subFeaturesStr);
    }
    out.flush();
}
Example usage of org.encog.engine.network.activation.ActivationFunction in the ShifuML/shifu project:
class PersistBasicFloatNetwork, method read.
/**
 * {@inheritDoc}
 *
 * <p>Reads a {@link BasicFloatNetwork} from Encog's EG text format, the
 * counterpart of {@code save(OutputStream, Object)}. Recognized subsections
 * of {@code BASIC} are {@code PARAMS}, {@code NETWORK}, {@code ACTIVATION}
 * and {@code SUBSET}; unknown sections are ignored.</p>
 *
 * @param is
 *            the stream to read the EG-format network from
 * @return the reconstructed {@link BasicFloatNetwork}
 * @throws PersistError
 *             if an activation-function class cannot be resolved or
 *             instantiated reflectively
 */
@Override
public final Object read(final InputStream is) {
    final BasicFloatNetwork result = new BasicFloatNetwork();
    final FlatNetwork flat = new FlatNetwork();
    final EncogReadHelper in = new EncogReadHelper(is);
    EncogFileSection section;
    while ((section = in.readNextSection()) != null) {
        if (section.getSectionName().equals("BASIC") && section.getSubSectionName().equals("PARAMS")) {
            final Map<String, String> params = section.parseParams();
            result.getProperties().putAll(params);
        }
        if (section.getSectionName().equals("BASIC") && section.getSubSectionName().equals("NETWORK")) {
            // flat-network scalar and array fields, in the order save() wrote them
            final Map<String, String> params = section.parseParams();
            flat.setBeginTraining(EncogFileSection.parseInt(params, BasicNetwork.TAG_BEGIN_TRAINING));
            flat.setConnectionLimit(EncogFileSection.parseDouble(params, BasicNetwork.TAG_CONNECTION_LIMIT));
            flat.setContextTargetOffset(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_CONTEXT_TARGET_OFFSET));
            flat.setContextTargetSize(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_CONTEXT_TARGET_SIZE));
            flat.setEndTraining(EncogFileSection.parseInt(params, BasicNetwork.TAG_END_TRAINING));
            flat.setHasContext(EncogFileSection.parseBoolean(params, BasicNetwork.TAG_HAS_CONTEXT));
            flat.setInputCount(EncogFileSection.parseInt(params, PersistConst.INPUT_COUNT));
            flat.setLayerCounts(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_COUNTS));
            flat.setLayerFeedCounts(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_FEED_COUNTS));
            flat.setLayerContextCount(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_CONTEXT_COUNT));
            flat.setLayerIndex(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_LAYER_INDEX));
            flat.setLayerOutput(EncogFileSection.parseDoubleArray(params, PersistConst.OUTPUT));
            // layer sums are not persisted; allocate a zeroed buffer of matching size
            flat.setLayerSums(new double[flat.getLayerOutput().length]);
            flat.setOutputCount(EncogFileSection.parseInt(params, PersistConst.OUTPUT_COUNT));
            flat.setWeightIndex(EncogFileSection.parseIntArray(params, BasicNetwork.TAG_WEIGHT_INDEX));
            flat.setWeights(EncogFileSection.parseDoubleArray(params, PersistConst.WEIGHTS));
            flat.setBiasActivation(EncogFileSection.parseDoubleArray(params, BasicNetwork.TAG_BIAS_ACTIVATION));
        } else if (section.getSectionName().equals("BASIC") && section.getSubSectionName().equals("ACTIVATION")) {
            int index = 0;
            flat.setActivationFunctions(new ActivationFunction[flat.getLayerCounts().length]);
            for (final String line : section.getLines()) {
                ActivationFunction af = null;
                final List<String> cols = EncogFileSection.splitColumns(line);
                // activation classes live in encog by default; shifu's own
                // activations are mapped to their ml.shifu packages explicitly
                String name = "org.encog.engine.network.activation." + cols.get(0);
                if (cols.get(0).equals("ActivationReLU")) {
                    name = "ml.shifu.shifu.core.dtrain.nn.ActivationReLU";
                } else if (cols.get(0).equals("ActivationLeakyReLU")) {
                    name = "ml.shifu.shifu.core.dtrain.nn.ActivationLeakyReLU";
                } else if (cols.get(0).equals("ActivationSwish")) {
                    name = "ml.shifu.shifu.core.dtrain.nn.ActivationSwish";
                } else if (cols.get(0).equals("ActivationPTANH")) {
                    name = "ml.shifu.shifu.core.dtrain.nn.ActivationPTANH";
                }
                try {
                    final Class<?> clazz = Class.forName(name);
                    // Class.newInstance() is deprecated and propagates checked
                    // constructor exceptions unchecked; go through the declared
                    // no-arg constructor instead.
                    af = (ActivationFunction) clazz.getDeclaredConstructor().newInstance();
                } catch (final ReflectiveOperationException e) {
                    // covers ClassNotFound, NoSuchMethod, Instantiation,
                    // IllegalAccess and InvocationTarget exceptions
                    throw new PersistError(e);
                }
                // remaining columns are the function's parameters, in order
                for (int i = 0; i < af.getParamNames().length; i++) {
                    af.setParam(i, CSVFormat.EG_FORMAT.parse(cols.get(i + 1)));
                }
                flat.getActivationFunctions()[index++] = af;
            }
        } else if (section.getSectionName().equals("BASIC") && section.getSubSectionName().equals("SUBSET")) {
            final Map<String, String> params = section.parseParams();
            String subsetStr = params.get("SUBSETFEATURES");
            if (StringUtils.isBlank(subsetStr)) {
                // blank means "no subset": the full feature set is in use
                result.setFeatureSet(null);
            } else {
                String[] splits = subsetStr.split(",");
                Set<Integer> subFeatures = new HashSet<Integer>();
                for (String split : splits) {
                    int featureIndex = Integer.parseInt(split);
                    subFeatures.add(featureIndex);
                }
                result.setFeatureSet(subFeatures);
            }
        }
    }
    result.getStructure().setFlat(flat);
    return result;
}
Aggregations