use of ml.shifu.shifu.core.alg.NNTrainer in project shifu by ShifuML.
the class TrainModelProcessor method runAkkaTrain.
/**
 * Runs the Akka-based training process with one trainer per bag.
 *
 * <p>Resets the local {@code models} directory, rebuilds the {@code trainers} list with
 * {@code numBags} trainers for the configured algorithm (NN, SVM or LR), opens scanners over
 * the input data and submits the training job through {@link AkkaSystemExecutor}. Once the
 * scanners have been opened they are always released, even if job submission fails.
 *
 * @param numBags
 *            number of bags; one trainer is created for each bag
 * @throws IOException
 *             if resetting the {@code models} directory fails (other I/O failures may
 *             propagate as well)
 * @throws ShifuException
 *             if the configured algorithm is not supported, or the input data cannot be opened
 */
private void runAkkaTrain(int numBags) throws IOException {
    File models = new File("models");
    FileUtils.deleteDirectory(models);
    FileUtils.forceMkdir(models);

    trainers.clear();
    for (int i = 0; i < numBags; i++) {
        AbstractTrainer trainer;
        if (modelConfig.getAlgorithm().equalsIgnoreCase("NN")) {
            trainer = new NNTrainer(modelConfig, i, isDryTrain);
        } else if (modelConfig.getAlgorithm().equalsIgnoreCase("SVM")) {
            trainer = new SVMTrainer(this.modelConfig, i, isDryTrain);
        } else if (modelConfig.getAlgorithm().equalsIgnoreCase("LR")) {
            trainer = new LogisticRegressionTrainer(this.modelConfig, i, isDryTrain);
        } else {
            throw new ShifuException(ShifuErrorCode.ERROR_UNSUPPORT_ALG);
        }
        trainers.add(trainer);
    }

    // NOTE(review): the loop above rejects every algorithm except NN/SVM/LR, so the "DT" path
    // below is only reachable when numBags <= 0 -- confirm whether that is intended.
    boolean decisionTree = modelConfig.getAlgorithm().equalsIgnoreCase("DT");

    // Decision trees read the raw data set; every other algorithm reads the normalized data.
    // Log and report the path that is actually opened (the DT branch used to name the
    // normalized path even though it read the raw one).
    String dataPath = decisionTree ? modelConfig.getDataSetRawPath() : pathFinder.getNormalizedDataPath();
    LOG.info((decisionTree ? "Raw Data: " : "Normalized Data: ") + dataPath);

    List<Scanner> scanners;
    try {
        scanners = ShifuFileUtils.getDataScanners(dataPath, modelConfig.getDataSet().getSource());
    } catch (IOException e) {
        throw new ShifuException(ShifuErrorCode.ERROR_INPUT_NOT_FOUND, e, dataPath);
    }

    try {
        if (CollectionUtils.isNotEmpty(scanners)) {
            if (decisionTree) {
                AkkaSystemExecutor.getExecutor().submitDecisionTreeTrainJob(modelConfig, columnConfigList, scanners, trainers);
            } else {
                AkkaSystemExecutor.getExecutor().submitModelTrainJob(modelConfig, columnConfigList, scanners, trainers);
            }
        }
    } finally {
        // release the scanners even when submitting the job throws
        closeScanners(scanners);
    }
}
use of ml.shifu.shifu.core.alg.NNTrainer in project shifu by ShifuML.
the class ValidationConductor method runValidate.
/**
 * Trains a throw-away {@link NNTrainer} on a train/validation split of the working column set
 * and reports the resulting validation error.
 *
 * <p>Model persistence and logging are switched off on the trainer before training.
 *
 * @return the value returned by {@code train()}, or the trainer's base MSE when training
 *         raises an {@link IOException}
 */
public double runValidate() {
    // Split the data for the current working columns into training and validation parts.
    MLDataSet trainPart = new BasicMLDataSet();
    MLDataSet validPart = new BasicMLDataSet();
    this.trainingDataSet.generateValidateData(this.workingColumnSet, this.modelConfig.getValidSetRate(),
            trainPart, validPart);

    // Configure a trainer that neither persists models nor logs.
    NNTrainer nnTrainer = new NNTrainer(this.modelConfig, 1, false);
    nnTrainer.setTrainSet(trainPart);
    nnTrainer.setValidSet(validPart);
    nnTrainer.disableModelPersistence();
    nnTrainer.disableLogging();

    // Train; an I/O failure is deliberately tolerated and mapped to the base MSE.
    try {
        return nnTrainer.train();
    } catch (IOException ignored) {
        return nnTrainer.getBaseMSE();
    }
}
use of ml.shifu.shifu.core.alg.NNTrainer in project shifu by ShifuML.
the class NNTrainerTest method testXorOperation.
/**
 * Trains a network with a single five-node tanh hidden layer on the XOR training fixture, then
 * checks each validation record against the expected "score below threshold" flag and verifies
 * the network has the configured number of hidden layers plus two.
 *
 * <p>Currently disabled: the {@code @Test} annotation below is commented out.
 */
// @Test
public void testXorOperation() throws IOException {
    List<Integer> hiddenNodes = new ArrayList<Integer>();
    hiddenNodes.add(5);
    List<String> activations = new ArrayList<String>();
    activations.add("tanh");

    ModelConfig xorConfig = ModelConfig.createInitModelConfig(".", ALGORITHM.NN, ".", false);
    xorConfig.getTrain().setBaggingSampleRate(1.0);
    xorConfig.getTrain().setValidSetRate(0.1);
    xorConfig.getTrain().getParams().put("Propagation", "Q");
    xorConfig.getTrain().getParams().put("NumHiddenLayers", 1);
    xorConfig.getTrain().getParams().put("LearningRate", 1);
    xorConfig.getTrain().getParams().put("NumHiddenNodes", hiddenNodes);
    xorConfig.getTrain().getParams().put("ActivationFunc", activations);
    xorConfig.getTrain().setNumTrainEpochs(100);

    NNTrainer xorTrainer = new NNTrainer(xorConfig, 0, false);
    xorTrainer.setTrainSet(xor_Trainset);
    xorTrainer.setValidSet(xor_Validset);
    xorTrainer.train();
    BasicNetwork network = xorTrainer.getNetwork();

    // Expected "scaled score is below 500" flag for each validation record, in iteration order.
    boolean[] expectedBelowThreshold = { true, false, false, true };
    int recordIndex = 0;
    for (MLDataPair record : xor_Validset) {
        double[] output = network.compute(record.getInput()).getData();
        boolean belowThreshold = output[0] * 1000 < 500;
        Assert.assertEquals(belowThreshold, expectedBelowThreshold[recordIndex]);
        recordIndex++;
    }

    // The layer count is the hidden layers plus two more layers.
    int expectedLayerCount = (Integer) (xorConfig.getTrain().getParams().get("NumHiddenLayers")) + 2;
    Assert.assertEquals(network.getLayerCount(), expectedLayerCount);
}
use of ml.shifu.shifu.core.alg.NNTrainer in project shifu by ShifuML.
the class AbstractTrainerTest method testLoad1.
/**
 * Loads 1000 random records into an {@link NNTrainer} under three configurations (defaults,
 * fixed initial input, bagging without replacement) and checks each time that the resulting
 * training-set size is within 5% of
 * {@code (1 - validSetRate) * baggingSampleRate * recordCount}.
 */
@Test
public void testLoad1() throws IOException {
    ModelConfig modelConfig = CommonUtils.loadModelConfig("src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ModelConfig.json", SourceType.LOCAL);

    // Build the random data set: one shared feature buffer refilled per record, binary target.
    MLDataSet set = new BasicMLDataSet();
    double[] features = new double[modelConfig.getVarSelectFilterNum()];
    for (int record = 0; record < 1000; record++) {
        for (int col = 0; col < modelConfig.getVarSelectFilterNum(); col++) {
            features[col] = random.nextDouble();
        }
        double[] target = new double[] { random.nextInt(2) };
        set.add(new BasicMLDataPair(new BasicMLData(features), new BasicMLData(target)));
    }

    // Case 1: in-memory training with the loaded configuration.
    modelConfig.getTrain().setTrainOnDisk(false);
    AbstractTrainer trainer = new NNTrainer(modelConfig, 0, false);
    trainer.setDataSet(set);
    assertTrainSetSizeWithinTolerance(trainer, modelConfig, set);

    // Case 2: fixed initial input enabled.
    modelConfig.getTrain().setFixInitInput(true);
    trainer = new NNTrainer(modelConfig, 0, false);
    trainer.setDataSet(set);
    assertTrainSetSizeWithinTolerance(trainer, modelConfig, set);

    // Case 3: fixed initial input disabled, bagging without replacement.
    modelConfig.getTrain().setFixInitInput(false);
    modelConfig.getTrain().setBaggingWithReplacement(false);
    trainer = new NNTrainer(modelConfig, 0, false);
    trainer.setDataSet(set);
    assertTrainSetSizeWithinTolerance(trainer, modelConfig, set);
}

/**
 * Asserts that the trainer's training-set record count lies within +/-5% of
 * {@code (1 - validSetRate) * baggingSampleRate * set.getRecordCount()}.
 */
private void assertTrainSetSizeWithinTolerance(AbstractTrainer trainer, ModelConfig modelConfig, MLDataSet set) {
    double expected = (1 - modelConfig.getValidSetRate()) * modelConfig.getBaggingSampleRate() * set.getRecordCount();
    double actual = trainer.getTrainSet().getRecordCount();
    Assert.assertTrue(actual <= expected * 1.05);
    Assert.assertTrue(actual >= expected * 0.95);
}
use of ml.shifu.shifu.core.alg.NNTrainer in project shifu by ShifuML.
the class TrainModelActorTest method testActor.
/**
 * Exercises the normalize and train actors end to end on the cancer-judgement example data.
 *
 * <p>Steps: normalize the raw data set into {@code ./tmp/NormalizedData}, train five bagged
 * models from it, assert that a model file exists for every bag, then remove the working
 * directories.
 *
 * @throws IOException
 *             if a working directory or data file cannot be created, read or deleted
 * @throws InterruptedException
 *             if the thread is interrupted while polling for actor termination
 */
@Test
public void testActor() throws IOException, InterruptedException {
    final int numTrainers = 5;

    File tmpDir = new File("./tmp");
    FileUtils.forceMkdir(tmpDir);

    // create normalize data
    actorSystem = ActorSystem.create("shifuActorSystem");
    ActorRef normalizeRef = actorSystem.actorOf(new Props(new UntypedActorFactory() {
        private static final long serialVersionUID = 6777309320338075269L;

        public UntypedActor create() throws IOException {
            return new NormalizeDataActor(modelConfig, columnConfigList, new AkkaExecStatus(true));
        }
    }), "normalize-calculator");
    List<Scanner> scanners = ShifuFileUtils.getDataScanners("src/test/resources/example/cancer-judgement/DataStore/DataSet1", SourceType.LOCAL);
    normalizeRef.tell(new AkkaActorInputMessage(scanners), normalizeRef);
    while (!normalizeRef.isTerminated()) {
        Thread.sleep(5000);
    }
    // Release the raw-data scanners now: the variable is reused for the normalized data below,
    // so they would otherwise never be closed.
    for (Scanner scanner : scanners) {
        scanner.close();
    }
    File outputFile = new File("./tmp/NormalizedData");
    Assert.assertTrue(outputFile.exists());

    // start to run trainer
    // NOTE(review): this replaces the first actor system without shutting it down -- confirm
    // whether the earlier system should be stopped first.
    actorSystem = ActorSystem.create("shifuActorSystem");
    File models = new File("./models");
    FileUtils.forceMkdir(models);
    final List<AbstractTrainer> trainers = new ArrayList<AbstractTrainer>();
    for (int i = 0; i < numTrainers; i++) {
        AbstractTrainer trainer;
        if (modelConfig.getAlgorithm().equalsIgnoreCase("NN")) {
            trainer = new NNTrainer(this.modelConfig, i, false);
        } else if (modelConfig.getAlgorithm().equalsIgnoreCase("SVM")) {
            trainer = new SVMTrainer(this.modelConfig, i, false);
        } else if (modelConfig.getAlgorithm().equalsIgnoreCase("LR")) {
            trainer = new LogisticRegressionTrainer(this.modelConfig, i, false);
        } else {
            throw new RuntimeException("unsupport algorithm");
        }
        trainers.add(trainer);
    }

    // train model
    ActorRef modelTrainRef = actorSystem.actorOf(new Props(new UntypedActorFactory() {
        private static final long serialVersionUID = 6777309320338075269L;

        public UntypedActor create() throws IOException {
            return new TrainModelActor(modelConfig, columnConfigList, new AkkaExecStatus(true), trainers);
        }
    }), "trainer");
    scanners = ShifuFileUtils.getDataScanners("./tmp/NormalizedData", SourceType.LOCAL);
    modelTrainRef.tell(new AkkaActorInputMessage(scanners), modelTrainRef);
    while (!modelTrainRef.isTerminated()) {
        Thread.sleep(5000);
    }
    for (Scanner scanner : scanners) {
        scanner.close();
    }

    // Verify one model file per bag. The previous version checked ./models/model0.nn five
    // times, so only the first bag was ever verified.
    // NOTE(review): presumably trainer i persists ./models/model<i>.nn -- confirm the naming.
    for (int i = 0; i < numTrainers; i++) {
        File model = new File("./models/model" + i + ".nn");
        Assert.assertTrue(model.exists());
    }

    // clean up working directories
    File modelsTemp = new File("./modelsTmp");
    FileUtils.deleteDirectory(modelsTemp);
    FileUtils.deleteDirectory(models);
    FileUtils.deleteDirectory(tmpDir);
}
Aggregations