Usage of `ml.shifu.shifu.core.MSEWorker` in the Shifu project (ShifuML):
the `calculateMSEParallel` method of class `NNTrainer`.
/**
 * Computes the mean squared error of {@code network} over {@code dataSet} by splitting the record
 * indices into ranges and evaluating each range concurrently in an {@link MSEWorker} task.
 *
 * <p>Each worker receives its own clone of the network and its own additional view of the data set
 * (obtained via {@code openAdditional()}). Those views are closed after every submitted worker has
 * completed, including when task creation or submission fails.
 *
 * @param network the network to evaluate; each worker operates on a clone of it
 * @param dataSet the records to evaluate; must contain between 1 and {@link Integer#MAX_VALUE}
 *     records
 * @return the sum of the workers' total errors divided by the number of records
 * @throws IllegalArgumentException if the data set is empty or has more than
 *     {@link Integer#MAX_VALUE} records
 */
public double calculateMSEParallel(BasicNetwork network, MLDataSet dataSet) {
    long recordCount = dataSet.getRecordCount();
    // Enforced explicitly rather than with `assert`, which is a no-op unless the JVM runs with -ea;
    // an empty data set would otherwise silently divide by zero below.
    if (recordCount <= 0) {
        throw new IllegalArgumentException(
                "Data set must contain at least one record, got " + recordCount);
    }
    // DetermineWorkload works on int indices; a plain cast would silently truncate large counts.
    if (recordCount > Integer.MAX_VALUE) {
        throw new IllegalArgumentException(
                "Data set too large for parallel MSE calculation: " + recordCount + " records");
    }
    int numRecords = (int) recordCount;

    // Partition the record indices [0, numRecords) into per-thread ranges.
    final DetermineWorkload determine = new DetermineWorkload(0, numRecords);
    MSEWorker[] workers = new MSEWorker[determine.getThreadCount()];
    // One data set view per worker, remembered here so every view can be closed afterwards.
    MLDataSet[] views = new MLDataSet[workers.length];
    int created = 0;
    TaskGroup group = EngineConcurrency.getInstance().createTaskGroup();
    try {
        for (final IntRange r : determine.calculateWorkers()) {
            // Separate network clone and data set view per worker so workers share no mutable state.
            views[created] = dataSet.openAdditional();
            workers[created] = new MSEWorker((BasicNetwork) network.clone(), views[created],
                    r.getLow(), r.getHigh());
            created++;
        }
        for (int i = 0; i < created; i++) {
            EngineConcurrency.getInstance().processTask(workers[i], group);
        }
    } finally {
        try {
            // Wait even on failure so no submitted worker is still reading a view when it is closed.
            group.waitForComplete();
        } finally {
            for (MLDataSet view : views) {
                if (view != null) {
                    view.close();
                }
            }
        }
    }

    double totalError = 0;
    for (int i = 0; i < created; i++) {
        totalError += workers[i].getTotalError();
    }
    return totalError / numRecords;
}
Aggregations