Usage of ml.shifu.shifu.container.obj.ModelConfig in project shifu by ShifuML.
Example from class ModelInspectorTest, method testValidateStats.
/**
 * Probes the STATS step with the default model configuration and expects validation to pass.
 * NOTE(review): {@code instance} is a field of the enclosing test class, presumably the
 * ModelInspector under test; confirm.
 */
@Test
public void testValidateStats() throws Exception {
    ValidateResult statsResult = instance.probe(CommonUtils.loadModelConfig(), ModelStep.STATS);
    Assert.assertTrue(statsResult.getStatus());
}
Usage of ml.shifu.shifu.container.obj.ModelConfig in project shifu by ShifuML.
Example from class ModelInspectorTest, method testValidatePostTrain.
/**
 * Probes the POSTTRAIN step with the default model configuration and expects validation to pass.
 * NOTE(review): {@code instance} is a field of the enclosing test class, presumably the
 * ModelInspector under test; confirm.
 */
@Test
public void testValidatePostTrain() throws Exception {
    ValidateResult postTrainResult = instance.probe(CommonUtils.loadModelConfig(), ModelStep.POSTTRAIN);
    Assert.assertTrue(postTrainResult.getStatus());
}
Usage of ml.shifu.shifu.container.obj.ModelConfig in project shifu by ShifuML.
Example from class NNTrainerTest, method testXorOperation.
// @Test
// NOTE(review): the @Test annotation is commented out, so this test never runs; confirm whether
// it was disabled on purpose (e.g. flakiness) or should be re-enabled.
/**
 * Trains an NN model on the XOR training set and verifies both its predictions and its topology.
 *
 * <p>Configuration used:
 * <ul>
 *   <li>bagging sample rate 1.0 and validation rate 0.1;</li>
 *   <li>propagation "Q" and learning rate 1;</li>
 *   <li>one hidden layer of 5 "tanh" nodes;</li>
 *   <li>100 epochs.</li>
 * </ul>
 *
 * <p>For each record of the XOR validation set, in iteration order, the test checks whether
 * {@code output * 1000 < 500} against an expected flag. It then checks that the network has
 * NumHiddenLayers + 2 layers.
 */
public void testXorOperation() throws IOException {
    ModelConfig config = ModelConfig.createInitModelConfig(".", ALGORITHM.NN, ".", false);
    config.getTrain().setBaggingSampleRate(1.0);
    config.getTrain().setValidSetRate(0.1);

    // Topology description: a single hidden layer with 5 tanh units.
    List<Integer> hiddenNodeCounts = new ArrayList<Integer>();
    hiddenNodeCounts.add(5);
    List<String> activations = new ArrayList<String>();
    activations.add("tanh");

    config.getTrain().getParams().put("Propagation", "Q");
    config.getTrain().getParams().put("NumHiddenLayers", 1);
    config.getTrain().getParams().put("LearningRate", 1);
    config.getTrain().getParams().put("NumHiddenNodes", hiddenNodeCounts);
    config.getTrain().getParams().put("ActivationFunc", activations);
    config.getTrain().setNumTrainEpochs(100);

    NNTrainer nnTrainer = new NNTrainer(config, 0, false);
    nnTrainer.setTrainSet(xor_Trainset);
    nnTrainer.setValidSet(xor_Validset);
    nnTrainer.train();

    BasicNetwork network = nnTrainer.getNetwork();
    // Expected outcome of "output below 0.5" for each validation record, in iteration order.
    boolean[] expectBelowHalf = { true, false, false, true };
    int row = 0;
    for (MLDataPair pair : xor_Validset) {
        double[] output = network.compute(pair.getInput()).getData();
        boolean belowHalf = output[0] * 1000 < 500;
        Assert.assertEquals(belowHalf, expectBelowHalf[row++]);
    }

    // The layer count is expected to be hidden layers plus the input and output layers.
    int hiddenLayers = (Integer) config.getTrain().getParams().get("NumHiddenLayers");
    Assert.assertEquals(network.getLayerCount(), hiddenLayers + 2);
}
Usage of ml.shifu.shifu.container.obj.ModelConfig in project shifu by ShifuML.
Example from class SVMTrainerTest, method setUp.
// MLDataSet dataSet;
// MLDataSet trainSet;
// MLDataSet validSet, testSet;
// Random random;
/**
 * Builds an SVM model configuration and creates the SVMTrainer under test, wired to the XOR
 * training and validation sets.
 *
 * <p>Configuration used:
 * <ul>
 *   <li>kernel "rbf", Const 1.1, Gamma 0.95;</li>
 *   <li>bagging sample rate 1.0, without replacement;</li>
 *   <li>variable-selection filter number 2 and data delimiter ",".</li>
 * </ul>
 *
 * <p>The result is stored in the {@code config} and {@code trainer} fields.
 */
@BeforeClass
public void setUp() throws IOException {
    config = new ModelConfig();
    config.getTrain().setAlgorithm("SVM");
    config.getVarSelect().setFilterNum(2);
    config.getDataSet().setDataDelimiter(",");
    // The data source is set exactly once, to the value that was previously in effect.
    // NOTE(review): the trainer below is fed in-memory XOR sets, so confirm whether HDFS
    // (rather than LOCAL) is really the intended source for this test.
    config.getDataSet().setSource(SourceType.HDFS);

    config.getTrain().setParams(new HashMap<String, Object>());
    config.getTrain().getParams().put("Const", 1.1);
    config.getTrain().getParams().put("Gamma", 0.95);
    config.getTrain().getParams().put("Kernel", "rbf");
    config.getTrain().setBaggingSampleRate(1.0);
    config.getTrain().setBaggingWithReplacement(false);

    trainer = new SVMTrainer(config, 0, false);
    trainer.setTrainSet(xor_Trainset);
    trainer.setValidSet(xor_Validset);
}
Usage of ml.shifu.shifu.container.obj.ModelConfig in project shifu by ShifuML.
Example from class AbstractTrainerTest, method testLoad1.
/**
 * Checks the training split that NNTrainer derives from an in-memory data set.
 *
 * <p>The data set has 1000 random records. Each record has getVarSelectFilterNum() inputs drawn
 * from {@code random.nextDouble()} and one ideal value drawn from {@code random.nextInt(2)}.
 *
 * <p>For three trainer configurations, the training split must hold
 * (1 - validSetRate) * baggingSampleRate of the records, within a +/-5% tolerance:
 * <ol>
 *   <li>the cancer-judgement example config with trainOnDisk = false;</li>
 *   <li>the same config with fixInitInput = true;</li>
 *   <li>fixInitInput = false and baggingWithReplacement = false.</li>
 * </ol>
 */
@Test
public void testLoad1() throws IOException {
    MLDataSet dataSet = new BasicMLDataSet();
    ModelConfig modelConfig = CommonUtils.loadModelConfig("src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ModelConfig.json", SourceType.LOCAL);
    int inputCount = modelConfig.getVarSelectFilterNum();
    for (int j = 0; j < 1000; j++) {
        // A fresh array per record, so that records never share backing storage regardless of
        // whether BasicMLData copies its argument.
        double[] input = new double[inputCount];
        for (int i = 0; i < inputCount; i++) {
            input[i] = random.nextDouble();
        }
        double[] ideal = new double[1];
        ideal[0] = random.nextInt(2);
        dataSet.add(new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal)));
    }

    // Configuration 1: the loaded config, training in memory.
    modelConfig.getTrain().setTrainOnDisk(false);
    AbstractTrainer trainer = new NNTrainer(modelConfig, 0, false);
    trainer.setDataSet(dataSet);
    assertTrainSetSizeWithinTolerance(trainer, modelConfig, dataSet);

    // Configuration 2: fixed initial input.
    modelConfig.getTrain().setFixInitInput(true);
    trainer = new NNTrainer(modelConfig, 0, false);
    trainer.setDataSet(dataSet);
    assertTrainSetSizeWithinTolerance(trainer, modelConfig, dataSet);

    // Configuration 3: no fixed initial input, bagging without replacement.
    modelConfig.getTrain().setFixInitInput(false);
    modelConfig.getTrain().setBaggingWithReplacement(false);
    trainer = new NNTrainer(modelConfig, 0, false);
    trainer.setDataSet(dataSet);
    assertTrainSetSizeWithinTolerance(trainer, modelConfig, dataSet);
}

/**
 * Asserts that the trainer's training set size is within +/-5% of the expected size.
 *
 * <p>The expected size is (1 - validSetRate) * baggingSampleRate * fullSet.getRecordCount().
 *
 * @param trainer     trainer whose {@code getTrainSet()} is measured
 * @param modelConfig configuration supplying the validation and bagging rates
 * @param fullSet     the complete data set that was handed to the trainer
 */
private static void assertTrainSetSizeWithinTolerance(AbstractTrainer trainer, ModelConfig modelConfig, MLDataSet fullSet) {
    double expected = (1 - modelConfig.getValidSetRate()) * modelConfig.getBaggingSampleRate() * fullSet.getRecordCount();
    double upper = expected * 1.05;
    double lower = expected * 0.95;
    long actual = trainer.getTrainSet().getRecordCount();
    Assert.assertTrue(actual <= upper, "train set size " + actual + " exceeds upper bound " + upper);
    Assert.assertTrue(actual >= lower, "train set size " + actual + " is below lower bound " + lower);
}
Aggregations