Use of ml.shifu.shifu.core.dtrain.dataset.FloatFlatNetwork in project shifu by ShifuML.
The class SubGradient, method calculateError.
/**
 * Calculate the error for this neural network. The error is accumulated
 * using root mean square (RMS).
 *
 * @param ec
 *            the shared error computation logic
 * @return always -1; the error itself is accumulated into the shared
 *         {@code ErrorCalculation}, which the owner reads once all
 *         partitions have finished.
 */
public final double calculateError(ErrorCalculation ec) {
    final double[] actual = new double[this.getNetwork().getOutputCount()];
    final FloatMLDataPair pair = BasicFloatMLDataPair.createPair(testing.getInputSize(), testing.getIdealSize());
    for (long i = testLow; i <= testHigh; i++) {
        synchronized (this.owner) {
            if (this.isCrossOver) {
                // select from the testing set for 3 of every 4 records; TODO remove this hard-coded 3:1 ratio
                if ((i + seed) % 4 < 3) {
                    this.testing.getRecord(i, pair);
                } else {
                    long trainingSize = this.training.getRecordCount();
                    // it is fine to sample from anywhere in the training set; wrap around if needed
                    if (i < trainingSize) {
                        this.training.getRecord(i, pair);
                    } else {
                        this.training.getRecord(i % trainingSize, pair);
                    }
                }
            } else {
                this.testing.getRecord(i, pair);
            }
        }
        ((FloatFlatNetwork) this.getNetwork()).compute(pair.getInputArray(), actual);
        // copy the float ideal array to double for API compatibility
        if (doubleIdeal == null) {
            doubleIdeal = new double[pair.getIdealArray().length];
        }
        for (int j = 0; j < doubleIdeal.length; j++) {
            doubleIdeal[j] = pair.getIdealArray()[j];
        }
        synchronized (ec) {
            ec.updateError(actual, doubleIdeal, pair.getSignificance());
        }
    }
    return -1;
}
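Note that calculateError never computes the error itself: each SubGradient thread only folds the residuals of its record partition into the shared ErrorCalculation under a lock, and the owner reads the aggregate once every partition has finished. A minimal sketch of the kind of accumulator this protocol assumes (a hypothetical RmsAccumulator for illustration, not Encog's actual ErrorCalculation implementation):

/** Hypothetical RMS accumulator sketch; illustrates the contract only. */
public class RmsAccumulator {
    private double sumOfSquares; // significance-weighted sum of squared residuals
    private double weight;       // total significance accumulated so far

    /**
     * Fold one record's residuals in, weighted by the record's significance.
     * Callers synchronize externally, as calculateError above does.
     */
    public void updateError(double[] actual, double[] ideal, double significance) {
        for (int i = 0; i < actual.length; i++) {
            double delta = ideal[i] - actual[i];
            sumOfSquares += delta * delta * significance;
        }
        weight += significance * actual.length;
    }

    /** Root mean square over everything accumulated across all partitions. */
    public double calculateRMS() {
        return weight == 0 ? 0 : Math.sqrt(sumOfSquares / weight);
    }
}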
Use of ml.shifu.shifu.core.dtrain.dataset.FloatFlatNetwork in project shifu by ShifuML.
The class AbstractNNWorker, method initGradient.
@SuppressWarnings("unchecked")
private void initGradient(FloatMLDataSet training, FloatMLDataSet testing, double[] weights, boolean isCrossOver) {
    int numLayers = (Integer) this.validParams.get(CommonConstants.NUM_HIDDEN_LAYERS);
    List<String> actFunc = (List<String>) this.validParams.get(CommonConstants.ACTIVATION_FUNC);
    List<Integer> hiddenNodeList = (List<Integer>) this.validParams.get(CommonConstants.NUM_HIDDEN_NODES);
    String outputActivationFunc = (String) validParams.get(CommonConstants.OUTPUT_ACTIVATION_FUNC);
    BasicNetwork network = DTrainUtils.generateNetwork(this.featureInputsCnt, this.outputNodeCount, numLayers,
            actFunc, hiddenNodeList, false, this.dropoutRate, this.wgtInit,
            CommonUtils.isLinearTarget(modelConfig, columnConfigList), outputActivationFunc);
    // use the weights received from the master
    network.getFlat().setWeights(weights);
    FlatNetwork flat = network.getFlat();
    // copied from Encog's Propagation: fix the flat-spot problem by adding a
    // small constant to the derivative of sigmoid activations
    double[] flatSpot = new double[flat.getActivationFunctions().length];
    for (int i = 0; i < flat.getActivationFunctions().length; i++) {
        flatSpot[i] = flat.getActivationFunctions()[i] instanceof ActivationSigmoid ? 0.1 : 0.0;
    }
    LOG.info("Gradient computing thread count is {}.", modelConfig.getTrain().getWorkerThreadCount());
    this.gradient = new ParallelGradient((FloatFlatNetwork) flat, training, testing, flatSpot,
            new LinearErrorFunction(), isCrossOver, modelConfig.getTrain().getWorkerThreadCount(), this.lossStr,
            this.batchs);
}
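The flatSpot array implements the classic flat-spot fix (Fahlman): the sigmoid derivative s(x) * (1 - s(x)) approaches zero when a unit saturates near 0 or 1, so its weight updates stall. Adding a small constant (0.1 above) to the derivative of sigmoid layers keeps saturated units learning. A hedged sketch of where that offset enters the backpropagation delta (illustrative names, not Shifu's actual Gradient internals):

// Sketch of the flat-spot correction during backpropagation.
// output, error, and flatSpot are illustrative parameters, not Shifu's fields.
static double sigmoidDelta(double output, double error, double flatSpot) {
    // True sigmoid derivative: s'(x) = s(x) * (1 - s(x)), which is near zero
    // when the unit saturates (output close to 0 or 1), stalling learning.
    double derivative = output * (1.0 - output);
    // The small constant keeps the delta non-zero for saturated units.
    return error * (derivative + flatSpot);
}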