use of org.encog.neural.networks.BasicNetwork in project shifu by ShifuML.
the class NNTrainer method buildNetwork.
@SuppressWarnings("unchecked")
public void buildNetwork() {
network = new BasicNetwork();
network.addLayer(new BasicLayer(new ActivationLinear(), true, trainSet.getInputSize()));
int numLayers = (Integer) modelConfig.getParams().get(CommonConstants.NUM_HIDDEN_LAYERS);
List<String> actFunc = (List<String>) modelConfig.getParams().get(CommonConstants.ACTIVATION_FUNC);
List<Integer> hiddenNodeList = (List<Integer>) modelConfig.getParams().get(CommonConstants.NUM_HIDDEN_NODES);
if (numLayers != 0 && (numLayers != actFunc.size() || numLayers != hiddenNodeList.size())) {
throw new RuntimeException("the number of layer do not equal to the number of activation function or the function list and node list empty");
}
if (toLoggingProcess)
LOG.info(" - total " + numLayers + " layers, each layers are: " + Arrays.toString(hiddenNodeList.toArray()) + " the activation function are: " + Arrays.toString(actFunc.toArray()));
for (int i = 0; i < numLayers; i++) {
String func = actFunc.get(i);
Integer numHiddenNode = hiddenNodeList.get(i);
// Java 6 compatibility: no switch on String, so use an if-else chain
if ("linear".equalsIgnoreCase(func)) {
network.addLayer(new BasicLayer(new ActivationLinear(), true, numHiddenNode));
} else if (func.equalsIgnoreCase("sigmoid")) {
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, numHiddenNode));
} else if (func.equalsIgnoreCase("tanh")) {
network.addLayer(new BasicLayer(new ActivationTANH(), true, numHiddenNode));
} else if (func.equalsIgnoreCase("log")) {
network.addLayer(new BasicLayer(new ActivationLOG(), true, numHiddenNode));
} else if (func.equalsIgnoreCase("sin")) {
network.addLayer(new BasicLayer(new ActivationSIN(), true, numHiddenNode));
} else {
LOG.info("Unsupported activation function: " + func + " !! Set this layer activation function to be Sigmoid ");
network.addLayer(new BasicLayer(new ActivationSigmoid(), true, numHiddenNode));
}
}
network.addLayer(new BasicLayer(new ActivationSigmoid(), false, trainSet.getIdealSize()));
network.getStructure().finalizeStructure();
if (!modelConfig.isFixInitialInput()) {
network.reset();
} else {
int numWeight = 0;
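// weights between layer i and i+1: (neurons in layer i, including the bias neuron) * neurons in layer i+1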
for (int i = 0; i < network.getLayerCount() - 1; i++) {
numWeight = numWeight + network.getLayerTotalNeuronCount(i) * network.getLayerNeuronCount(i + 1);
}
LOG.info(" - You have " + numWeight + " weights to be initialize");
loadWeightsInput(numWeight);
}
}
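For orientation, here is a minimal, self-contained sketch of the same Encog layer-stacking pattern used by buildNetwork above; the layer sizes and activation choices are illustrative only, not taken from shifu's ModelConfig.
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.engine.network.activation.ActivationTANH;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;

public class BuildNetworkSketch {
    public static BasicNetwork build(int inputCount, int outputCount) {
        BasicNetwork network = new BasicNetwork();
        // input layer: linear activation, bias enabled
        network.addLayer(new BasicLayer(new ActivationLinear(), true, inputCount));
        // two hypothetical hidden layers
        network.addLayer(new BasicLayer(new ActivationTANH(), true, 30));
        network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 20));
        // output layer: no bias neuron on the final layer
        network.addLayer(new BasicLayer(new ActivationSigmoid(), false, outputCount));
        network.getStructure().finalizeStructure();
        network.reset(); // randomize the initial weights
        return network;
    }
}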
use of org.encog.neural.networks.BasicNetwork in project shifu by ShifuML.
the class NNTrainer method calculateMSEParallel.
public double calculateMSEParallel(BasicNetwork network, MLDataSet dataSet) {
int numRecords = (int) dataSet.getRecordCount();
assert numRecords > 0;
// setup workers
final DetermineWorkload determine = new DetermineWorkload(0, numRecords);
// one MSE worker per available thread
MSEWorker[] workers = new MSEWorker[determine.getThreadCount()];
int index = 0;
TaskGroup group = EngineConcurrency.getInstance().createTaskGroup();
for (final IntRange r : determine.calculateWorkers()) {
workers[index++] = new MSEWorker((BasicNetwork) network.clone(), dataSet.openAdditional(), r.getLow(), r.getHigh());
}
for (final MSEWorker worker : workers) {
EngineConcurrency.getInstance().processTask(worker, group);
}
group.waitForComplete();
double totalError = 0;
for (final MSEWorker worker : workers) {
totalError += worker.getTotalError();
}
return totalError / numRecords;
}
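For comparison, a single-threaded version of the same computation; note that, like the parallel method above, it divides the summed squared error by the record count. Encog's own BasicNetwork.calculateError(MLDataSet) computes a related error measure directly.
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.BasicNetwork;

public final class MseSketch {
    public static double mse(BasicNetwork network, MLDataSet dataSet) {
        double totalError = 0;
        for (MLDataPair pair : dataSet) {
            double[] actual = network.compute(pair.getInput()).getData();
            double[] ideal = pair.getIdeal().getData();
            for (int i = 0; i < ideal.length; i++) {
                double diff = ideal[i] - actual[i];
                totalError += diff * diff;
            }
        }
        // divide by record count, matching the parallel version above
        return totalError / dataSet.getRecordCount();
    }
}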
use of org.encog.neural.networks.BasicNetwork in project shifu by ShifuML.
the class LogisticRegressionTrainer method train.
/**
* {@inheritDoc}
* <p>
* No <code>regularization</code> is applied yet.
* <p>
* Regularization will be provided later.
* <p>
*
* @throws IOException
* if saving the trained model fails
*/
@Override
public double train() throws IOException {
classifier = new BasicNetwork();
classifier.addLayer(new BasicLayer(new ActivationLinear(), true, trainSet.getInputSize()));
classifier.addLayer(new BasicLayer(new ActivationSigmoid(), false, trainSet.getIdealSize()));
classifier.getStructure().finalizeStructure();
// resetParams(classifier);
classifier.reset();
// Propagation mlTrain = getMLTrain();
Propagation propagation = new QuickPropagation(classifier, trainSet, (Double) modelConfig.getParams().get("LearningRate"));
int epochs = modelConfig.getNumTrainEpochs();
// Get convergence threshold from modelConfig.
double threshold = modelConfig.getTrain().getConvergenceThreshold() == null ? 0.0 : modelConfig.getTrain().getConvergenceThreshold().doubleValue();
String formattedThreshold = df.format(threshold);
LOG.info("Using " + (Double) modelConfig.getParams().get("LearningRate") + " training rate");
for (int i = 0; i < epochs; i++) {
propagation.iteration();
double trainError = propagation.getError();
double validError = classifier.calculateError(this.validSet);
LOG.info("Epoch #" + (i + 1) + " Train Error:" + df.format(trainError) + " Validation Error:" + df.format(validError));
// Convergence judging.
double avgErr = (trainError + validError) / 2;
if (judger.judge(avgErr, threshold)) {
LOG.info("Trainer-{}> Epoch #{} converged! Average Error: {}, Threshold: {}", trainerID, (i + 1), df.format(avgErr), formatedThreshold);
break;
}
}
propagation.finishTraining();
LOG.info("#" + this.trainerID + " finish training");
saveLR();
return 0.0d;
}
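The convergence check above averages train and validation error and hands the result to a judger. A hypothetical stand-in for that judge step (the real judger in shifu may differ) could look like:
public final class ConvergenceSketch {
    // returns true when the averaged error has dropped to the threshold;
    // a threshold of 0.0 (the default above) never triggers early stopping
    public static boolean judge(double trainError, double validError, double threshold) {
        double avgErr = (trainError + validError) / 2;
        return threshold > 0.0 && avgErr <= threshold;
    }
}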
use of org.encog.neural.networks.BasicNetwork in project shifu by ShifuML.
the class Scorer method scoreNsData.
public ScoreObject scoreNsData(MLDataPair inputPair, Map<NSColumn, String> rawNsDataMap) {
if (inputPair == null && !this.alg.equalsIgnoreCase(NNConstants.NN_ALG_NAME)) {
inputPair = NormalUtils.assembleNsDataPair(binCategoryMap, noVarSelect, modelConfig, selectedColumnConfigList, rawNsDataMap, cutoff, alg);
}
// clear cache
this.cachedNormDataPair.clear();
final MLDataPair pair = inputPair;
List<MLData> modelResults = new ArrayList<MLData>();
List<Callable<MLData>> tasks = null;
if (this.multiThread) {
tasks = new ArrayList<Callable<MLData>>();
}
for (final BasicML model : models) {
// TODO check whether this 'if' condition is needed and refactor the two if/for blocks
if (model instanceof BasicFloatNetwork || model instanceof NNModel) {
final BasicFloatNetwork network = (model instanceof BasicFloatNetwork) ? (BasicFloatNetwork) model : ((NNModel) model).getIndependentNNModel().getBasicNetworks().get(0);
String cacheKey = featureSetToString(network.getFeatureSet());
MLDataPair dataPair = cachedNormDataPair.get(cacheKey);
if (dataPair == null) {
dataPair = NormalUtils.assembleNsDataPair(binCategoryMap, noVarSelect, modelConfig, selectedColumnConfigList, rawNsDataMap, cutoff, alg, network.getFeatureSet());
cachedNormDataPair.put(cacheKey, dataPair);
}
final MLDataPair networkPair = dataPair;
/*
* if(network.getFeatureSet().size() != networkPair.getInput().size()) {
* log.error("Network and input size mismatch: Network Size = " + network.getFeatureSet().size()
* + "; Input Size = " + networkPair.getInput().size());
* continue;
* }
*/
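// sampled logging: only fires when the clock lands exactly on a second boundary, to avoid flooding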
if (System.currentTimeMillis() % 1000 == 0L) {
log.info("Network input count = {}, while input size = {}", network.getInputCount(), networkPair.getInput().size());
}
final int fnlOutputHiddenLayerIndex = outputHiddenLayerIndex;
Callable<MLData> callable = new Callable<MLData>() {
@Override
public MLData call() {
MLData finalOutput = network.compute(networkPair.getInput());
if (fnlOutputHiddenLayerIndex == 0) {
return finalOutput;
}
// append output values in hidden layer
double[] hiddenOutputs = network.getLayerOutput(fnlOutputHiddenLayerIndex);
double[] outputs = new double[finalOutput.getData().length + hiddenOutputs.length];
System.arraycopy(finalOutput.getData(), 0, outputs, 0, finalOutput.getData().length);
System.arraycopy(hiddenOutputs, 0, outputs, finalOutput.getData().length, hiddenOutputs.length);
return new BasicMLData(outputs);
}
};
if (multiThread) {
tasks.add(callable);
} else {
try {
modelResults.add(callable.call());
} catch (Exception e) {
log.error("error in model evaluation", e);
}
}
} else if (model instanceof BasicNetwork) {
final BasicNetwork network = (BasicNetwork) model;
final MLDataPair networkPair = NormalUtils.assembleNsDataPair(binCategoryMap, noVarSelect, modelConfig, columnConfigList, rawNsDataMap, cutoff, alg, null);
Callable<MLData> callable = new Callable<MLData>() {
@Override
public MLData call() {
return network.compute(networkPair.getInput());
}
};
if (multiThread) {
tasks.add(callable);
} else {
try {
modelResults.add(callable.call());
} catch (Exception e) {
log.error("error in model evaluation", e);
}
}
} else if (model instanceof SVM) {
final SVM svm = (SVM) model;
if (svm.getInputCount() != pair.getInput().size()) {
log.error("SVM and input size mismatch: SVM Size = " + svm.getInputCount() + "; Input Size = " + pair.getInput().size());
continue;
}
Callable<MLData> callable = new Callable<MLData>() {
@Override
public MLData call() {
return svm.compute(pair.getInput());
}
};
if (multiThread) {
tasks.add(callable);
} else {
try {
modelResults.add(callable.call());
} catch (Exception e) {
log.error("error in model evaluation", e);
}
}
} else if (model instanceof LR) {
final LR lr = (LR) model;
if (lr.getInputCount() != pair.getInput().size()) {
log.error("LR and input size mismatch: LR Size = " + lr.getInputCount() + "; Input Size = " + pair.getInput().size());
continue;
}
Callable<MLData> callable = new Callable<MLData>() {
@Override
public MLData call() {
return lr.compute(pair.getInput());
}
};
if (multiThread) {
tasks.add(callable);
} else {
try {
modelResults.add(callable.call());
} catch (Exception e) {
log.error("error in model evaluation", e);
}
}
} else if (model instanceof TreeModel) {
final TreeModel tm = (TreeModel) model;
if (tm.getInputCount() != pair.getInput().size()) {
throw new RuntimeException("GBDT and input size mismatch: tm input Size = " + tm.getInputCount() + "; data input Size = " + pair.getInput().size());
}
Callable<MLData> callable = new Callable<MLData>() {
@Override
public MLData call() {
MLData result = tm.compute(pair.getInput());
return result;
}
};
if (multiThread) {
tasks.add(callable);
} else {
try {
modelResults.add(callable.call());
} catch (Exception e) {
log.error("error in model evaluation", e);
}
}
} else if (model instanceof GenericModel) {
Callable<MLData> callable = new Callable<MLData>() {
@Override
public MLData call() {
return ((GenericModel) model).compute(pair.getInput());
}
};
if (multiThread) {
tasks.add(callable);
} else {
try {
modelResults.add(callable.call());
} catch (Exception e) {
log.error("error in model evaluation", e);
}
}
} else {
throw new RuntimeException("unsupport models");
}
}
List<Double> scores = new ArrayList<Double>();
List<Integer> rfTreeSizeList = new ArrayList<Integer>();
SortedMap<String, Double> hiddenOutputs = null;
if (CollectionUtils.isNotEmpty(modelResults) || CollectionUtils.isNotEmpty(tasks)) {
int modelSize = modelResults.size() > 0 ? modelResults.size() : tasks.size();
if (modelSize != this.models.size()) {
log.error("Get model results size doesn't match with models size.");
return null;
}
if (multiThread) {
modelResults = this.executorManager.submitTasksAndWaitResults(tasks);
} else {
// single-threaded: modelResults has already been populated inside callable.call()
}
if (this.outputHiddenLayerIndex != 0) {
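// hidden-output keys have the form model_<modelIdx>_<layerIdx>_<nodeIdx>; sort by each numeric part in turn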
hiddenOutputs = new TreeMap<String, Double>(new Comparator<String>() {
@Override
public int compare(String o1, String o2) {
String[] split1 = o1.split("_");
String[] split2 = o2.split("_");
int model1Index = Integer.parseInt(split1[1]);
int model2Index = Integer.parseInt(split2[1]);
if (model1Index > model2Index) {
return 1;
} else if (model1Index < model2Index) {
return -1;
} else {
int hidden1Index = Integer.parseInt(split1[2]);
int hidden2Index = Integer.parseInt(split2[2]);
if (hidden1Index > hidden2Index) {
return 1;
} else if (hidden1Index < hidden2Index) {
return -1;
} else {
int hidden11Index = Integer.parseInt(split1[3]);
int hidden22Index = Integer.parseInt(split2[3]);
return Integer.valueOf(hidden11Index).compareTo(Integer.valueOf(hidden22Index));
}
}
}
});
}
for (int i = 0; i < this.models.size(); i++) {
BasicML model = this.models.get(i);
MLData score = modelResults.get(i);
if (model instanceof BasicNetwork || model instanceof NNModel) {
if (modelConfig != null && modelConfig.isRegression()) {
scores.add(toScore(score.getData(0)));
if (this.outputHiddenLayerIndex != 0) {
for (int j = 1; j < score.getData().length; j++) {
hiddenOutputs.put("model_" + i + "_" + this.outputHiddenLayerIndex + "_" + (j - 1), score.getData()[j]);
}
}
} else if (modelConfig != null && modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()) {
// if one vs all classification
scores.add(toScore(score.getData(0)));
} else {
double[] outputs = score.getData();
for (double d : outputs) {
scores.add(toScore(d));
}
}
} else if (model instanceof SVM) {
scores.add(toScore(score.getData(0)));
} else if (model instanceof LR) {
scores.add(toScore(score.getData(0)));
} else if (model instanceof TreeModel) {
if (modelConfig.isClassification() && !modelConfig.getTrain().isOneVsAll()) {
double[] scoreArray = score.getData();
for (double sc : scoreArray) {
scores.add(sc);
}
} else {
// if one vs all multiple classification or regression
scores.add(toScore(score.getData(0)));
}
final TreeModel tm = (TreeModel) model;
// regression for RF
if (!tm.isClassfication() && !tm.isGBDT()) {
rfTreeSizeList.add(tm.getTrees().size());
}
} else if (model instanceof GenericModel) {
scores.add(toScore(score.getData(0)));
} else {
throw new RuntimeException("unsupport models");
}
}
}
Integer tag = Constants.DEFAULT_IDEAL_VALUE;
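// sampled warning: throttled via the clock so empty-score cases do not flood the log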
if (scores.size() == 0 && System.currentTimeMillis() % 100 == 0) {
log.warn("No Scores Calculated...");
}
return new ScoreObject(scores, tag, rfTreeSizeList, hiddenOutputs);
}
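The multiThread branch defers the callables to this.executorManager. Here is a minimal sketch of the submit-tasks-and-wait-for-results pattern it presumably wraps; the actual ExecutorManager in shifu may pool and reuse threads differently.
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public final class SubmitAndWaitSketch {
    public static <T> List<T> submitTasksAndWaitResults(List<Callable<T>> tasks) throws Exception {
        ExecutorService pool = Executors.newFixedThreadPool(Math.max(1, tasks.size()));
        try {
            // invokeAll blocks until every task has completed
            List<Future<T>> futures = pool.invokeAll(tasks);
            List<T> results = new ArrayList<T>(futures.size());
            for (Future<T> future : futures) {
                results.add(future.get()); // rethrows any task failure
            }
            return results;
        } finally {
            pool.shutdown();
        }
    }
}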
use of org.encog.neural.networks.BasicNetwork in project shifu by ShifuML.
the class NNMaster method initWeights.
@SuppressWarnings({ "unchecked" })
private NNParams initWeights() {
NNParams params = new NNParams();
boolean isLinearTarget = CommonUtils.isLinearTarget(modelConfig, columnConfigList);
int[] inputAndOutput = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), this.columnConfigList);
int featureInputsCnt = DTrainUtils.getFeatureInputsCnt(modelConfig, this.columnConfigList, new HashSet<Integer>(this.subFeatures));
@SuppressWarnings("unused") int inputNodeCount = inputAndOutput[0] == 0 ? inputAndOutput[2] : inputAndOutput[0];
// for one-vs-all classification, outputNodeCount is set to 1; for binary classification (classes == 2) it is also 1
int classes = modelConfig.getTags().size();
int outputNodeCount = (isLinearTarget || modelConfig.isRegression()) ? inputAndOutput[1] : (modelConfig.getTrain().isOneVsAll() ? inputAndOutput[1] : (classes == 2 ? 1 : classes));
int numLayers = (Integer) validParams.get(CommonConstants.NUM_HIDDEN_LAYERS);
List<String> actFunc = (List<String>) validParams.get(CommonConstants.ACTIVATION_FUNC);
List<Integer> hiddenNodeList = (List<Integer>) validParams.get(CommonConstants.NUM_HIDDEN_NODES);
String outputActivationFunc = (String) validParams.get(CommonConstants.OUTPUT_ACTIVATION_FUNC);
BasicNetwork network = DTrainUtils.generateNetwork(featureInputsCnt, outputNodeCount, numLayers, actFunc, hiddenNodeList, true, this.dropoutRate, this.wgtInit, CommonUtils.isLinearTarget(modelConfig, columnConfigList), outputActivationFunc);
this.flatNetwork = (FloatFlatNetwork) network.getFlat();
params.setTrainError(0);
params.setTestError(0);
// prevent null pointer
params.setGradients(new double[0]);
params.setWeights(network.getFlat().getWeights());
return params;
}
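initWeights hands the flattened weight array of the freshly built network to the NNParams message. A small sketch of reading that flat weight vector, valid only after finalizeStructure() has been called on the network:
import org.encog.neural.networks.BasicNetwork;

public final class FlatWeightsSketch {
    public static double[] copyFlatWeights(BasicNetwork network) {
        // getFlat() requires a finalized structure; the returned array is live,
        // so clone it before handing it to another component
        return network.getFlat().getWeights().clone();
    }
}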