Usage of ml.shifu.shifu.core.alg.NNTrainer in the shifu project by ShifuML.
From class NNTrainerTest, method testExceptionWhileSetupModel.
/**
 * Checks that {@link NNTrainer#buildNetwork()} fails with a {@link RuntimeException} for an
 * inconsistent network configuration.
 *
 * <p>The configuration declares {@code NumHiddenLayers = 2}, but supplies three
 * {@code NumHiddenNodes} entries and only one {@code ActivationFunc} entry, with
 * {@code Propagation = "Q"}. NOTE(review): presumably one of these mismatches is what makes
 * {@code buildNetwork()} throw. Confirm which check fires.
 *
 * @throws IOException if loading {@code xor_Trainset} into the trainer fails. This is a
 *     checked exception, not a {@link RuntimeException}, so it fails the test instead of
 *     being counted as the expected failure.
 */
@Test(expectedExceptions = RuntimeException.class)
public void testExceptionWhileSetupModel() throws IOException {
    ModelConfig config = ModelConfig.createInitModelConfig(".", ALGORITHM.NN, ".", false);
    config.getTrain().getParams().put("Propagation", "Q");
    config.getTrain().getParams().put("NumHiddenLayers", 2);
    config.getTrain().getParams().put("LearningRate", 0.1);

    // Three node counts and a single activation function for a declared two hidden layers;
    // the sizes deliberately disagree with NumHiddenLayers.
    List<Integer> nodes = new ArrayList<Integer>();
    nodes.add(3);
    nodes.add(3);
    nodes.add(3);
    List<String> func = new ArrayList<String>();
    func.add("tanh");
    config.getTrain().getParams().put("NumHiddenNodes", nodes);
    config.getTrain().getParams().put("ActivationFunc", func);
    config.getTrain().setNumTrainEpochs(50);

    NNTrainer trainer = new NNTrainer(config, 0, false);
    // Let data-loading errors propagate rather than swallowing them. Previously a failed
    // setDataSet was ignored and the test carried on to buildNetwork(), hiding broken fixtures.
    trainer.setDataSet(xor_Trainset);
    trainer.buildNetwork();
}
Usage of ml.shifu.shifu.core.alg.NNTrainer in the shifu project by ShifuML.
From class ScorerTest, method setup.
/**
 * Trains one neural network and one SVM on the same four-row truth table and registers both
 * models in {@code models}, NN first and SVM second.
 *
 * <p>The shared {@code modelConfig} is first filled in for the NN. It is switched to the SVM
 * settings only after the NN has finished training.
 *
 * @throws IOException declared by the original signature; propagated from the trainers if
 *     they raise it.
 */
@BeforeClass
public void setup() throws IOException {
    // Neural network: two hidden layers of 3 and 4 nodes, with linear then tanh activations.
    modelConfig = ModelConfig.createInitModelConfig(".", ALGORITHM.NN, ".", false);
    modelConfig.getTrain().getParams().put("Propagation", "B");
    modelConfig.getTrain().getParams().put("NumHiddenLayers", 2);
    modelConfig.getTrain().getParams().put("LearningRate", 0.5);

    List<Integer> hiddenNodes = new ArrayList<Integer>();
    hiddenNodes.add(3);
    hiddenNodes.add(4);
    modelConfig.getTrain().getParams().put("NumHiddenNodes", hiddenNodes);

    List<String> activations = new ArrayList<String>();
    activations.add("linear");
    activations.add("tanh");
    modelConfig.getTrain().getParams().put("ActivationFunc", activations);

    NNTrainer nnTrainer = new NNTrainer(modelConfig, 0, false);

    // Two-input truth table whose target is 1 exactly when both inputs are equal.
    // Rows are appended in the order (0,0), (0,1), (1,0), (1,1).
    double[][] inputs = { { 0., 0. }, { 0., 1. }, { 1., 0. }, { 1., 1. } };
    double[][] targets = { { 1. }, { 0. }, { 0. }, { 1. } };
    for (int row = 0; row < inputs.length; row++) {
        MLDataPair pair =
                new BasicMLDataPair(new BasicMLData(inputs[row]), new BasicMLData(targets[row]));
        set.add(pair);
    }

    // The same data serves as both the training set and the validation set.
    nnTrainer.setTrainSet(set);
    nnTrainer.setValidSet(set);
    nnTrainer.train();

    // SVM: reconfigure the shared modelConfig now that NN training is complete.
    modelConfig.getTrain().setAlgorithm("SVM");
    modelConfig.getTrain().getParams().put("Kernel", "rbf");
    modelConfig.getTrain().getParams().put("Const", 0.1);
    modelConfig.getTrain().getParams().put("Gamma", 1.0);
    modelConfig.getVarSelect().setFilterNum(2);

    SVMTrainer svmTrainer = new SVMTrainer(modelConfig, 1, false);
    svmTrainer.setTrainSet(set);
    svmTrainer.setValidSet(set);
    svmTrainer.train();

    models.add(nnTrainer.getNetwork());
    models.add(svmTrainer.getSVM());
}
Aggregations