Search in sources :

Example 6 with NNTrainer

Use of ml.shifu.shifu.core.alg.NNTrainer in project shifu by ShifuML.

From class NNTrainerTest, method testExceptionWhileSetupModel.

/**
 * Verifies that setting up an NN trainer from an inconsistent configuration fails with a
 * {@link RuntimeException}.
 *
 * <p>NOTE(review): the configuration looks deliberately inconsistent. It declares 2 hidden
 * layers but supplies 3 hidden-node counts and only 1 activation function. Which setting (or
 * which call below) actually raises the exception is not visible from this test; confirm.
 *
 * @throws IOException if the XOR training set cannot be handed to the trainer. This now fails
 *     the test instead of being silently swallowed.
 */
@Test(expectedExceptions = RuntimeException.class)
public void testExceptionWhileSetupModel() throws IOException {
    ModelConfig config = ModelConfig.createInitModelConfig(".", ALGORITHM.NN, ".", false);
    config.getTrain().getParams().put("Propagation", "Q");
    config.getTrain().getParams().put("NumHiddenLayers", 2);
    config.getTrain().getParams().put("LearningRate", 0.1);
    List<Integer> nodes = new ArrayList<Integer>();
    nodes.add(3);
    nodes.add(3);
    nodes.add(3);
    List<String> func = new ArrayList<String>();
    func.add("tanh");
    config.getTrain().getParams().put("NumHiddenNodes", nodes);
    config.getTrain().getParams().put("ActivationFunc", func);
    config.getTrain().setNumTrainEpochs(50);
    NNTrainer trainer = new NNTrainer(config, 0, false);
    // An IOException here means the fixture data never reached the trainer. Let it propagate
    // (the method already declares it) so the test fails loudly. Swallowing it could let the
    // test pass for the wrong reason.
    trainer.setDataSet(xor_Trainset);
    trainer.buildNetwork();
}
Also used: ModelConfig(ml.shifu.shifu.container.obj.ModelConfig) NNTrainer(ml.shifu.shifu.core.alg.NNTrainer) ArrayList(java.util.ArrayList) IOException(java.io.IOException) Test(org.testng.annotations.Test)

Example 7 with NNTrainer

Use of ml.shifu.shifu.core.alg.NNTrainer in project shifu by ShifuML.

From class ScorerTest, method setup.

/**
 * Builds the shared scorer fixtures: an NN model and an RBF-kernel SVM model, both trained on
 * the same four-row data set and appended to {@code models} in that order (NN first).
 *
 * <p>The data set encodes XNOR: the ideal output is 1 exactly when both inputs are equal. The
 * same rows are used as both training and validation data. {@code modelConfig} is shared. After
 * the NN has been trained, it is switched to the "SVM" algorithm and SVM parameters are added
 * to it.
 */
@BeforeClass
public void setup() throws IOException {
    modelConfig = ModelConfig.createInitModelConfig(".", ALGORITHM.NN, ".", false);
    modelConfig.getTrain().getParams().put("Propagation", "B");
    modelConfig.getTrain().getParams().put("NumHiddenLayers", 2);
    modelConfig.getTrain().getParams().put("LearningRate", 0.5);

    List<Integer> hiddenNodes = new ArrayList<Integer>();
    hiddenNodes.add(3);
    hiddenNodes.add(4);
    List<String> activations = new ArrayList<String>();
    activations.add("linear");
    activations.add("tanh");
    modelConfig.getTrain().getParams().put("NumHiddenNodes", hiddenNodes);
    modelConfig.getTrain().getParams().put("ActivationFunc", activations);

    NNTrainer nnTrainer = new NNTrainer(modelConfig, 0, false);

    // Each row is {input0, input1, ideal}; rows are added in this exact order.
    double[][] truthTable = {
        { 0., 0., 1. },
        { 0., 1., 0. },
        { 1., 0., 0. },
        { 1., 1., 1. }
    };
    for (double[] row : truthTable) {
        MLDataPair pair = new BasicMLDataPair(
                new BasicMLData(new double[] { row[0], row[1] }),
                new BasicMLData(new double[] { row[2] }));
        set.add(pair);
    }

    nnTrainer.setTrainSet(set);
    nnTrainer.setValidSet(set);
    nnTrainer.train();

    // Reuse the same config for the SVM: switch the algorithm and add SVM-specific settings.
    modelConfig.getTrain().setAlgorithm("SVM");
    modelConfig.getTrain().getParams().put("Kernel", "rbf");
    modelConfig.getTrain().getParams().put("Const", 0.1);
    modelConfig.getTrain().getParams().put("Gamma", 1.0);
    modelConfig.getVarSelect().setFilterNum(2);

    SVMTrainer svmTrainer = new SVMTrainer(modelConfig, 1, false);
    svmTrainer.setTrainSet(set);
    svmTrainer.setValidSet(set);
    svmTrainer.train();

    models.add(nnTrainer.getNetwork());
    models.add(svmTrainer.getSVM());
}
Also used: BasicMLDataPair(org.encog.ml.data.basic.BasicMLDataPair) MLDataPair(org.encog.ml.data.MLDataPair) NNTrainer(ml.shifu.shifu.core.alg.NNTrainer) BasicMLData(org.encog.ml.data.basic.BasicMLData) ArrayList(java.util.ArrayList) SVMTrainer(ml.shifu.shifu.core.alg.SVMTrainer) BeforeClass(org.testng.annotations.BeforeClass)

Aggregations

NNTrainer (ml.shifu.shifu.core.alg.NNTrainer)7 ArrayList (java.util.ArrayList)4 ModelConfig (ml.shifu.shifu.container.obj.ModelConfig)3 SVMTrainer (ml.shifu.shifu.core.alg.SVMTrainer)3 MLDataPair (org.encog.ml.data.MLDataPair)3 BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair)3 Test (org.testng.annotations.Test)3 IOException (java.io.IOException)2 AbstractTrainer (ml.shifu.shifu.core.AbstractTrainer)2 LogisticRegressionTrainer (ml.shifu.shifu.core.alg.LogisticRegressionTrainer)2 MLDataSet (org.encog.ml.data.MLDataSet)2 BasicMLData (org.encog.ml.data.basic.BasicMLData)2 BasicMLDataSet (org.encog.ml.data.basic.BasicMLDataSet)2 File (java.io.File)1 Scanner (java.util.Scanner)1 ShifuException (ml.shifu.shifu.exception.ShifuException)1 AkkaActorInputMessage (ml.shifu.shifu.message.AkkaActorInputMessage)1 BasicNetwork (org.encog.neural.networks.BasicNetwork)1 BeforeClass (org.testng.annotations.BeforeClass)1