Example 1 with BasicNetwork

use of org.encog.neural.networks.BasicNetwork in project shifu by ShifuML.

the class NNTrainer method buildNetwork.

@SuppressWarnings("unchecked")
public void buildNetwork() {
    network = new BasicNetwork();
    network.addLayer(new BasicLayer(new ActivationLinear(), true, trainSet.getInputSize()));
    int numLayers = (Integer) modelConfig.getParams().get(CommonConstants.NUM_HIDDEN_LAYERS);
    List<String> actFunc = (List<String>) modelConfig.getParams().get(CommonConstants.ACTIVATION_FUNC);
    List<Integer> hiddenNodeList = (List<Integer>) modelConfig.getParams().get(CommonConstants.NUM_HIDDEN_NODES);
    if (numLayers != 0 && (numLayers != actFunc.size() || numLayers != hiddenNodeList.size())) {
        throw new RuntimeException("the number of layer do not equal to the number of activation function or the function list and node list empty");
    }
    if (toLoggingProcess)
        LOG.info("    - total " + numLayers + " hidden layers, nodes per layer: " + Arrays.toString(hiddenNodeList.toArray()) + ", activation functions: " + Arrays.toString(actFunc.toArray()));
    for (int i = 0; i < numLayers; i++) {
        String func = actFunc.get(i);
        Integer numHiddenNode = hiddenNodeList.get(i);
        // if/else chain instead of switch-on-string for Java 6 compatibility
        if ("linear".equalsIgnoreCase(func)) {
            network.addLayer(new BasicLayer(new ActivationLinear(), true, numHiddenNode));
        } else if (func.equalsIgnoreCase("sigmoid")) {
            network.addLayer(new BasicLayer(new ActivationSigmoid(), true, numHiddenNode));
        } else if (func.equalsIgnoreCase("tanh")) {
            network.addLayer(new BasicLayer(new ActivationTANH(), true, numHiddenNode));
        } else if (func.equalsIgnoreCase("log")) {
            network.addLayer(new BasicLayer(new ActivationLOG(), true, numHiddenNode));
        } else if (func.equalsIgnoreCase("sin")) {
            network.addLayer(new BasicLayer(new ActivationSIN(), true, numHiddenNode));
        } else {
            LOG.info("Unsupported activation function: " + func + " !! Set this layer activation function to be Sigmoid ");
            network.addLayer(new BasicLayer(new ActivationSigmoid(), true, numHiddenNode));
        }
    }
    network.addLayer(new BasicLayer(new ActivationSigmoid(), false, trainSet.getIdealSize()));
    network.getStructure().finalizeStructure();
    if (!modelConfig.isFixInitialInput()) {
        network.reset();
    } else {
        int numWeight = 0;
        for (int i = 0; i < network.getLayerCount() - 1; i++) {
            numWeight = numWeight + network.getLayerTotalNeuronCount(i) * network.getLayerNeuronCount(i + 1);
        }
        LOG.info("    - You have " + numWeight + " weights to be initialize");
        loadWeightsInput(numWeight);
    }
}
Also used : ActivationLinear(org.encog.engine.network.activation.ActivationLinear) ActivationSIN(org.encog.engine.network.activation.ActivationSIN) ActivationLOG(org.encog.engine.network.activation.ActivationLOG) BasicNetwork(org.encog.neural.networks.BasicNetwork) ActivationTANH(org.encog.engine.network.activation.ActivationTANH) ActivationSigmoid(org.encog.engine.network.activation.ActivationSigmoid) ArrayList(java.util.ArrayList) List(java.util.List) BasicLayer(org.encog.neural.networks.layers.BasicLayer)
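
For readers new to the Encog API, here is a minimal, self-contained sketch of the same input/hidden/output wiring. The layer sizes (4 inputs, one 10-node hidden layer, 1 output) are made up for illustration; buildNetwork reads them from the model config instead.

import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;

public class BuildNetworkSketch {
    public static void main(String[] args) {
        BasicNetwork network = new BasicNetwork();
        // linear input layer with bias, as in buildNetwork above
        network.addLayer(new BasicLayer(new ActivationLinear(), true, 4));
        // one sigmoid hidden layer with bias
        network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 10));
        // sigmoid output layer, no bias
        network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 1));
        network.getStructure().finalizeStructure();
        // random weight initialization (the non-fixInitialInput branch above)
        network.reset();
        for (int i = 0; i < network.getLayerCount(); i++) {
            System.out.println("layer " + i + ": " + network.getLayerNeuronCount(i) + " neurons");
        }
    }
}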

Example 2 with BasicNetwork

use of org.encog.neural.networks.BasicNetwork in project shifu by ShifuML.

the class NNTrainer method calculateMSEParallel.

public double calculateMSEParallel(BasicNetwork network, MLDataSet dataSet) {
    int numRecords = (int) dataSet.getRecordCount();
    assert numRecords > 0;
    // set up one MSE worker per thread; DetermineWorkload splits the record range
    MSEWorker[] workers = new MSEWorker[determine.getThreadCount()];
    int index = 0;
    TaskGroup group = EngineConcurrency.getInstance().createTaskGroup();
    for (final IntRange r : determine.calculateWorkers()) {
        workers[index++] = new MSEWorker((BasicNetwork) network.clone(), dataSet.openAdditional(), r.getLow(), r.getHigh());
    }
    for (final MSEWorker worker : workers) {
        EngineConcurrency.getInstance().processTask(worker, group);
    }
    group.waitForComplete();
    double totalError = 0;
    for (final MSEWorker worker : workers) {
        totalError += worker.getTotalError();
    }
    return totalError / numRecords;
}
Also used : MSEWorker(ml.shifu.shifu.core.MSEWorker) BasicNetwork(org.encog.neural.networks.BasicNetwork) DetermineWorkload(org.encog.util.concurrency.DetermineWorkload) IntRange(org.encog.mathutil.IntRange) TaskGroup(org.encog.util.concurrency.TaskGroup)
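
As a sanity check against the parallel path, the same MSE can be computed single-threaded with a plain loop. This is a sketch under the assumption that MSEWorker accumulates the sum of squared errors over all output elements; the real worker may differ in details such as record weighting.

import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.BasicNetwork;

public final class MseSketch {
    public static double mse(BasicNetwork network, MLDataSet dataSet) {
        double sumSquared = 0;
        long count = 0;
        // MLDataSet is Iterable<MLDataPair>, so no workload split is needed here
        for (MLDataPair pair : dataSet) {
            double[] actual = network.compute(pair.getInput()).getData();
            double[] ideal = pair.getIdeal().getData();
            for (int i = 0; i < ideal.length; i++) {
                double diff = ideal[i] - actual[i];
                sumSquared += diff * diff;
            }
            count++;
        }
        if (count == 0) {
            throw new IllegalArgumentException("empty data set");
        }
        return sumSquared / count;
    }
}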

Example 3 with BasicNetwork

use of org.encog.neural.networks.BasicNetwork in project shifu by ShifuML.

the class LogisticRegressionTrainer method train.

/**
 * {@inheritDoc}
 * <p>
 * No <code>regularization</code> yet.
 * <p>
 * Regularization will be provided later.
 * <p>
 *
 * @throws IOException
 *             if the trained model cannot be saved
 */
@Override
public double train() throws IOException {
    classifier = new BasicNetwork();
    classifier.addLayer(new BasicLayer(new ActivationLinear(), true, trainSet.getInputSize()));
    classifier.addLayer(new BasicLayer(new ActivationSigmoid(), false, trainSet.getIdealSize()));
    classifier.getStructure().finalizeStructure();
    // resetParams(classifier);
    classifier.reset();
    // Propagation mlTrain = getMLTrain();
    Propagation propagation = new QuickPropagation(classifier, trainSet, (Double) modelConfig.getParams().get("LearningRate"));
    int epochs = modelConfig.getNumTrainEpochs();
    // Get convergence threshold from modelConfig.
    double threshold = modelConfig.getTrain().getConvergenceThreshold() == null ? 0.0 : modelConfig.getTrain().getConvergenceThreshold().doubleValue();
    String formattedThreshold = df.format(threshold);
    LOG.info("Using " + (Double) modelConfig.getParams().get("LearningRate") + " training rate");
    for (int i = 0; i < epochs; i++) {
        propagation.iteration();
        double trainError = propagation.getError();
        double validError = classifier.calculateError(this.validSet);
        LOG.info("Epoch #" + (i + 1) + " Train Error:" + df.format(trainError) + " Validation Error:" + df.format(validError));
        // Convergence judging.
        double avgErr = (trainError + validError) / 2;
        if (judger.judge(avgErr, threshold)) {
            LOG.info("Trainer-{}> Epoch #{} converged! Average Error: {}, Threshold: {}", trainerID, (i + 1), df.format(avgErr), formatedThreshold);
            break;
        }
    }
    propagation.finishTraining();
    LOG.info("#" + this.trainerID + " finish training");
    saveLR();
    return 0.0d;
}
Also used : BasicNetwork(org.encog.neural.networks.BasicNetwork) Propagation(org.encog.neural.networks.training.propagation.Propagation) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) ActivationLinear(org.encog.engine.network.activation.ActivationLinear) QuickPropagation(org.encog.neural.networks.training.propagation.quick.QuickPropagation) ActivationSigmoid(org.encog.engine.network.activation.ActivationSigmoid) BasicLayer(org.encog.neural.networks.layers.BasicLayer)
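
The network built by train() is exactly a logistic regression: linear inputs feeding a single sigmoid output. Below is a self-contained version trained with QuickPropagation on toy OR data; the data, learning rate, and epoch count are invented for illustration.

import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.quick.QuickPropagation;

public class LogisticRegressionSketch {
    public static void main(String[] args) {
        // linearly separable OR data, a case logistic regression can fit
        double[][] input = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
        double[][] ideal = { { 0 }, { 1 }, { 1 }, { 1 } };
        MLDataSet trainSet = new BasicMLDataSet(input, ideal);

        BasicNetwork classifier = new BasicNetwork();
        classifier.addLayer(new BasicLayer(new ActivationLinear(), true, 2));
        classifier.addLayer(new BasicLayer(new ActivationSigmoid(), false, 1));
        classifier.getStructure().finalizeStructure();
        classifier.reset();

        Propagation propagation = new QuickPropagation(classifier, trainSet, 0.1);
        for (int i = 0; i < 100; i++) {
            propagation.iteration();
        }
        propagation.finishTraining();
        System.out.println("final train error: " + propagation.getError());
    }
}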

Example 4 with BasicNetwork

use of org.encog.neural.networks.BasicNetwork in project shifu by ShifuML.

the class Scorer method scoreNsData.

public ScoreObject scoreNsData(MLDataPair inputPair, Map<NSColumn, String> rawNsDataMap) {
    if (inputPair == null && !this.alg.equalsIgnoreCase(NNConstants.NN_ALG_NAME)) {
        inputPair = NormalUtils.assembleNsDataPair(binCategoryMap, noVarSelect, modelConfig, selectedColumnConfigList, rawNsDataMap, cutoff, alg);
    }
    // clear cache
    this.cachedNormDataPair.clear();
    final MLDataPair pair = inputPair;
    List<MLData> modelResults = new ArrayList<MLData>();
    List<Callable<MLData>> tasks = null;
    if (this.multiThread) {
        tasks = new ArrayList<Callable<MLData>>();
    }
    for (final BasicML model : models) {
        // TODO: check whether this 'if' is needed and refactor the duplicated if/for blocks
        if (model instanceof BasicFloatNetwork || model instanceof NNModel) {
            final BasicFloatNetwork network = (model instanceof BasicFloatNetwork) ? (BasicFloatNetwork) model : ((NNModel) model).getIndependentNNModel().getBasicNetworks().get(0);
            String cacheKey = featureSetToString(network.getFeatureSet());
            MLDataPair dataPair = cachedNormDataPair.get(cacheKey);
            if (dataPair == null) {
                dataPair = NormalUtils.assembleNsDataPair(binCategoryMap, noVarSelect, modelConfig, selectedColumnConfigList, rawNsDataMap, cutoff, alg, network.getFeatureSet());
                cachedNormDataPair.put(cacheKey, dataPair);
            }
            final MLDataPair networkPair = dataPair;
            /*
                 * if(network.getFeatureSet().size() != networkPair.getInput().size()) {
                 * log.error("Network and input size mismatch: Network Size = " + network.getFeatureSet().size()
                 * + "; Input Size = " + networkPair.getInput().size());
                 * continue;
                 * }
                 */
            // throttled logging: only fires when the clock lands exactly on a second boundary
            if (System.currentTimeMillis() % 1000 == 0L) {
                log.info("Network input count = {}, while input size = {}", network.getInputCount(), networkPair.getInput().size());
            }
            final int fnlOutputHiddenLayerIndex = outputHiddenLayerIndex;
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    MLData finalOutput = network.compute(networkPair.getInput());
                    if (fnlOutputHiddenLayerIndex == 0) {
                        return finalOutput;
                    }
                    // append output values in hidden layer
                    double[] hiddenOutputs = network.getLayerOutput(fnlOutputHiddenLayerIndex);
                    double[] outputs = new double[finalOutput.getData().length + hiddenOutputs.length];
                    System.arraycopy(finalOutput.getData(), 0, outputs, 0, finalOutput.getData().length);
                    System.arraycopy(hiddenOutputs, 0, outputs, finalOutput.getData().length, hiddenOutputs.length);
                    return new BasicMLData(outputs);
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else if (model instanceof BasicNetwork) {
            final BasicNetwork network = (BasicNetwork) model;
            final MLDataPair networkPair = NormalUtils.assembleNsDataPair(binCategoryMap, noVarSelect, modelConfig, columnConfigList, rawNsDataMap, cutoff, alg, null);
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    return network.compute(networkPair.getInput());
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else if (model instanceof SVM) {
            final SVM svm = (SVM) model;
            if (svm.getInputCount() != pair.getInput().size()) {
                log.error("SVM and input size mismatch: SVM Size = " + svm.getInputCount() + "; Input Size = " + pair.getInput().size());
                continue;
            }
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    return svm.compute(pair.getInput());
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else if (model instanceof LR) {
            final LR lr = (LR) model;
            if (lr.getInputCount() != pair.getInput().size()) {
                log.error("LR and input size mismatch: LR Size = " + lr.getInputCount() + "; Input Size = " + pair.getInput().size());
                continue;
            }
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    return lr.compute(pair.getInput());
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else if (model instanceof TreeModel) {
            final TreeModel tm = (TreeModel) model;
            if (tm.getInputCount() != pair.getInput().size()) {
                throw new RuntimeException("GBDT and input size mismatch: tm input Size = " + tm.getInputCount() + "; data input Size = " + pair.getInput().size());
            }
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    MLData result = tm.compute(pair.getInput());
                    return result;
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else if (model instanceof GenericModel) {
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    return ((GenericModel) model).compute(pair.getInput());
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else {
            throw new RuntimeException("unsupport models");
        }
    }
    List<Double> scores = new ArrayList<Double>();
    List<Integer> rfTreeSizeList = new ArrayList<Integer>();
    SortedMap<String, Double> hiddenOutputs = null;
    if (CollectionUtils.isNotEmpty(modelResults) || CollectionUtils.isNotEmpty(tasks)) {
        int modelSize = modelResults.size() > 0 ? modelResults.size() : tasks.size();
        if (modelSize != this.models.size()) {
            log.error("Get model results size doesn't match with models size.");
            return null;
        }
        if (multiThread) {
            modelResults = this.executorManager.submitTasksAndWaitResults(tasks);
        } else {
            // single-threaded path: modelResults was already populated via callable.call() above
        }
        if (this.outputHiddenLayerIndex != 0) {
            // keys look like "model_<modelIndex>_<hiddenLayerIndex>_<nodeIndex>" (built below);
            // order by model index, then hidden layer index, then node index
            hiddenOutputs = new TreeMap<String, Double>(new Comparator<String>() {

                @Override
                public int compare(String o1, String o2) {
                    String[] split1 = o1.split("_");
                    String[] split2 = o2.split("_");
                    int model1Index = Integer.parseInt(split1[1]);
                    int model2Index = Integer.parseInt(split2[1]);
                    if (model1Index > model2Index) {
                        return 1;
                    } else if (model1Index < model2Index) {
                        return -1;
                    } else {
                        int hidden1Index = Integer.parseInt(split1[2]);
                        int hidden2Index = Integer.parseInt(split2[2]);
                        if (hidden1Index > hidden2Index) {
                            return 1;
                        } else if (hidden1Index < hidden2Index) {
                            return -1;
                        } else {
                            int node1Index = Integer.parseInt(split1[3]);
                            int node2Index = Integer.parseInt(split2[3]);
                            return Integer.valueOf(node1Index).compareTo(Integer.valueOf(node2Index));
                        }
                    }
                }
            });
        }
        for (int i = 0; i < this.models.size(); i++) {
            BasicML model = this.models.get(i);
            MLData score = modelResults.get(i);
            if (model instanceof BasicNetwork || model instanceof NNModel) {
                if (modelConfig != null && modelConfig.isRegression()) {
                    scores.add(toScore(score.getData(0)));
                    if (this.outputHiddenLayerIndex != 0) {
                        for (int j = 1; j < score.getData().length; j++) {
                            hiddenOutputs.put("model_" + i + "_" + this.outputHiddenLayerIndex + "_" + (j - 1), score.getData()[j]);
                        }
                    }
                } else if (modelConfig != null && modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()) {
                    // if one vs all classification
                    scores.add(toScore(score.getData(0)));
                } else {
                    double[] outputs = score.getData();
                    for (double d : outputs) {
                        scores.add(toScore(d));
                    }
                }
            } else if (model instanceof SVM) {
                scores.add(toScore(score.getData(0)));
            } else if (model instanceof LR) {
                scores.add(toScore(score.getData(0)));
            } else if (model instanceof TreeModel) {
                if (modelConfig.isClassification() && !modelConfig.getTrain().isOneVsAll()) {
                    double[] scoreArray = score.getData();
                    for (double sc : scoreArray) {
                        scores.add(sc);
                    }
                } else {
                    // if one vs all multiple classification or regression
                    scores.add(toScore(score.getData(0)));
                }
                final TreeModel tm = (TreeModel) model;
                // regression for RF
                if (!tm.isClassfication() && !tm.isGBDT()) {
                    rfTreeSizeList.add(tm.getTrees().size());
                }
            } else if (model instanceof GenericModel) {
                scores.add(toScore(score.getData(0)));
            } else {
                throw new RuntimeException("unsupport models");
            }
        }
    }
    Integer tag = Constants.DEFAULT_IDEAL_VALUE;
    // throttled warning (roughly 1 in 100 calls) to avoid flooding the log
    if (scores.size() == 0 && System.currentTimeMillis() % 100 == 0) {
        log.warn("No Scores Calculated...");
    }
    return new ScoreObject(scores, tag, rfTreeSizeList, hiddenOutputs);
}
Also used : MLDataPair(org.encog.ml.data.MLDataPair) ArrayList(java.util.ArrayList) BasicML(org.encog.ml.BasicML) SVM(org.encog.ml.svm.SVM) Callable(java.util.concurrent.Callable) Comparator(java.util.Comparator) BasicMLData(org.encog.ml.data.basic.BasicMLData) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) BasicMLData(org.encog.ml.data.basic.BasicMLData) MLData(org.encog.ml.data.MLData) ScoreObject(ml.shifu.shifu.container.ScoreObject) BasicNetwork(org.encog.neural.networks.BasicNetwork)
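
executorManager.submitTasksAndWaitResults is shifu-internal. Assuming it behaves like a plain invokeAll over a fixed thread pool, returning results in submission order (which the index-based matching against this.models requires), a minimal stand-in could look like this; the class name here is hypothetical.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.encog.ml.data.MLData;

public final class ParallelScoreSketch {
    public static List<MLData> submitTasksAndWaitResults(List<Callable<MLData>> tasks) throws Exception {
        ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        try {
            // invokeAll blocks until every task completes and preserves submission order
            List<Future<MLData>> futures = pool.invokeAll(tasks);
            List<MLData> results = new ArrayList<MLData>(futures.size());
            for (Future<MLData> future : futures) {
                results.add(future.get()); // rethrows any model-evaluation exception
            }
            return results;
        } finally {
            pool.shutdown();
        }
    }
}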

Example 5 with BasicNetwork

use of org.encog.neural.networks.BasicNetwork in project shifu by ShifuML.

the class NNMaster method initWeights.

@SuppressWarnings({ "unchecked" })
private NNParams initWeights() {
    NNParams params = new NNParams();
    boolean isLinearTarget = CommonUtils.isLinearTarget(modelConfig, columnConfigList);
    int[] inputAndOutput = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), this.columnConfigList);
    int featureInputsCnt = DTrainUtils.getFeatureInputsCnt(modelConfig, this.columnConfigList, new HashSet<Integer>(this.subFeatures));
    @SuppressWarnings("unused") int inputNodeCount = inputAndOutput[0] == 0 ? inputAndOutput[2] : inputAndOutput[0];
    // for one-vs-all classification outputNodeCount is 1; for binary classification (classes == 2) it is also 1
    int classes = modelConfig.getTags().size();
    int outputNodeCount = (isLinearTarget || modelConfig.isRegression()) ? inputAndOutput[1] : (modelConfig.getTrain().isOneVsAll() ? inputAndOutput[1] : (classes == 2 ? 1 : classes));
    int numLayers = (Integer) validParams.get(CommonConstants.NUM_HIDDEN_LAYERS);
    List<String> actFunc = (List<String>) validParams.get(CommonConstants.ACTIVATION_FUNC);
    List<Integer> hiddenNodeList = (List<Integer>) validParams.get(CommonConstants.NUM_HIDDEN_NODES);
    String outputActivationFunc = (String) validParams.get(CommonConstants.OUTPUT_ACTIVATION_FUNC);
    BasicNetwork network = DTrainUtils.generateNetwork(featureInputsCnt, outputNodeCount, numLayers, actFunc, hiddenNodeList, true, this.dropoutRate, this.wgtInit, CommonUtils.isLinearTarget(modelConfig, columnConfigList), outputActivationFunc);
    this.flatNetwork = (FloatFlatNetwork) network.getFlat();
    params.setTrainError(0);
    params.setTestError(0);
    // prevent null pointer exceptions downstream
    params.setGradients(new double[0]);
    params.setWeights(network.getFlat().getWeights());
    return params;
}
Also used : BasicNetwork(org.encog.neural.networks.BasicNetwork) ArrayList(java.util.ArrayList) List(java.util.List)
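
DTrainUtils.generateNetwork is shifu-internal, but the weight array handed to NNParams comes straight from Encog's flat network, which finalizeStructure() creates. A minimal sketch with hypothetical layer sizes showing where that array lives:

import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;

public class InitWeightsSketch {
    public static void main(String[] args) {
        BasicNetwork network = new BasicNetwork();
        network.addLayer(new BasicLayer(new ActivationLinear(), true, 3));
        network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 5));
        network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 1));
        network.getStructure().finalizeStructure();
        network.reset();
        // every weight, bias weights included, lives in one contiguous array:
        // (3 + 1) * 5 + (5 + 1) * 1 = 26 weights for these layer sizes
        FlatNetwork flat = network.getFlat();
        double[] weights = flat.getWeights();
        System.out.println("flat weight count: " + weights.length);
    }
}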

Aggregations

BasicNetwork (org.encog.neural.networks.BasicNetwork)14 ArrayList (java.util.ArrayList)6 BasicML (org.encog.ml.BasicML)5 FlatNetwork (org.encog.neural.flat.FlatNetwork)5 File (java.io.File)4 BasicLayer (org.encog.neural.networks.layers.BasicLayer)4 List (java.util.List)3 ActivationSigmoid (org.encog.engine.network.activation.ActivationSigmoid)3 Test (org.testng.annotations.Test)3 NNMaster (ml.shifu.shifu.core.dtrain.nn.NNMaster)2 ActivationLinear (org.encog.engine.network.activation.ActivationLinear)2 MLDataPair (org.encog.ml.data.MLDataPair)2 Comparator (java.util.Comparator)1 Callable (java.util.concurrent.Callable)1 ScoreObject (ml.shifu.shifu.container.ScoreObject)1 ModelConfig (ml.shifu.shifu.container.obj.ModelConfig)1 MSEWorker (ml.shifu.shifu.core.MSEWorker)1 NNTrainer (ml.shifu.shifu.core.alg.NNTrainer)1 BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)1 FloatFlatNetwork (ml.shifu.shifu.core.dtrain.dataset.FloatFlatNetwork)1