Search in sources :

Example 1 with NNTrainer

use of ml.shifu.shifu.core.alg.NNTrainer in project shifu by ShifuML.

the class TrainModelProcessor method runAkkaTrain.

/**
 * Runs the Akka-based training process with the given number of bags.
 * <p>
 * Recreates the local {@code models} directory, builds one trainer per bag for the configured
 * algorithm (NN, SVM or LR), opens scanners over the input data and submits the training job to
 * the Akka executor. The scanners are released even when the job submission throws.
 *
 * @param numBags
 *            number of bags; it decides how many trainers are created
 * @throws IOException
 *             if the {@code models} directory cannot be deleted or re-created
 * @throws ShifuException
 *             if the configured algorithm is not supported, or the input data cannot be opened
 */
private void runAkkaTrain(int numBags) throws IOException {
    File models = new File("models");
    FileUtils.deleteDirectory(models);
    FileUtils.forceMkdir(models);
    trainers.clear();
    String algorithm = modelConfig.getAlgorithm();
    for (int i = 0; i < numBags; i++) {
        AbstractTrainer trainer;
        if (algorithm.equalsIgnoreCase("NN")) {
            trainer = new NNTrainer(modelConfig, i, isDryTrain);
        } else if (algorithm.equalsIgnoreCase("SVM")) {
            trainer = new SVMTrainer(this.modelConfig, i, isDryTrain);
        } else if (algorithm.equalsIgnoreCase("LR")) {
            trainer = new LogisticRegressionTrainer(this.modelConfig, i, isDryTrain);
        } else {
            throw new ShifuException(ShifuErrorCode.ERROR_UNSUPPORT_ALG);
        }
        trainers.add(trainer);
    }
    // NOTE(review): the loop above rejects every algorithm other than NN/SVM/LR, so with
    // numBags > 0 the "DT" branch below cannot be reached - confirm whether a DT trainer is missing.
    boolean decisionTree = algorithm.equalsIgnoreCase("DT");
    List<Scanner> scanners = null;
    if (decisionTree) {
        // Decision trees read the raw data set, so log and report the path that is actually opened.
        LOG.info("Raw Data: " + modelConfig.getDataSetRawPath());
        try {
            scanners = ShifuFileUtils.getDataScanners(modelConfig.getDataSetRawPath(), modelConfig.getDataSet().getSource());
        } catch (IOException e) {
            throw new ShifuException(ShifuErrorCode.ERROR_INPUT_NOT_FOUND, e, modelConfig.getDataSetRawPath());
        }
    } else {
        LOG.info("Normalized Data: " + pathFinder.getNormalizedDataPath());
        try {
            scanners = ShifuFileUtils.getDataScanners(pathFinder.getNormalizedDataPath(), modelConfig.getDataSet().getSource());
        } catch (IOException e) {
            throw new ShifuException(ShifuErrorCode.ERROR_INPUT_NOT_FOUND, e, pathFinder.getNormalizedDataPath());
        }
    }
    try {
        if (CollectionUtils.isNotEmpty(scanners)) {
            if (decisionTree) {
                AkkaSystemExecutor.getExecutor().submitDecisionTreeTrainJob(modelConfig, columnConfigList, scanners, trainers);
            } else {
                AkkaSystemExecutor.getExecutor().submitModelTrainJob(modelConfig, columnConfigList, scanners, trainers);
            }
        }
    } finally {
        // release the scanners whether or not the job submission succeeded
        closeScanners(scanners);
    }
}
Also used : NNTrainer(ml.shifu.shifu.core.alg.NNTrainer) SVMTrainer(ml.shifu.shifu.core.alg.SVMTrainer) LogisticRegressionTrainer(ml.shifu.shifu.core.alg.LogisticRegressionTrainer) AbstractTrainer(ml.shifu.shifu.core.AbstractTrainer) ShifuException(ml.shifu.shifu.exception.ShifuException)

Example 2 with NNTrainer

use of ml.shifu.shifu.core.alg.NNTrainer in project shifu by ShifuML.

the class ValidationConductor method runValidate.

/**
 * Trains an {@code NNTrainer} on a train/validation split of the working column set, with model
 * persistence and logging disabled, and reports the resulting error.
 *
 * @return the value returned by {@code NNTrainer.train()}, or the trainer's base MSE when
 *         training throws an {@link IOException}
 */
public double runValidate() {
    // Step 1: split the data into a training part and a validation part.
    MLDataSet trainPart = new BasicMLDataSet();
    MLDataSet validPart = new BasicMLDataSet();
    this.trainingDataSet.generateValidateData(this.workingColumnSet, this.modelConfig.getValidSetRate(), trainPart, validPart);

    // Step 2: set up the trainer with persistence and logging switched off.
    NNTrainer nnTrainer = new NNTrainer(this.modelConfig, 1, false);
    nnTrainer.setTrainSet(trainPart);
    nnTrainer.setValidSet(validPart);
    nnTrainer.disableModelPersistence();
    nnTrainer.disableLogging();

    // Step 3: train; an I/O failure is deliberately not propagated and the base MSE is returned instead.
    try {
        return nnTrainer.train();
    } catch (IOException ignored) {
        return nnTrainer.getBaseMSE();
    }
}
Also used : NNTrainer(ml.shifu.shifu.core.alg.NNTrainer) MLDataSet(org.encog.ml.data.MLDataSet) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) IOException(java.io.IOException)

Example 3 with NNTrainer

use of ml.shifu.shifu.core.alg.NNTrainer in project shifu by ShifuML.

the class NNTrainerTest method testXorOperation.

/**
 * Trains an NN model on the XOR fixture sets and checks, for each validation pair, whether the
 * first network output scaled by 1000 falls below 500, then checks the network's layer count
 * against the configured number of hidden layers plus two.
 */
// @Test
public void testXorOperation() throws IOException {
    ModelConfig xorConfig = ModelConfig.createInitModelConfig(".", ALGORITHM.NN, ".", false);
    xorConfig.getTrain().setBaggingSampleRate(1.0);
    xorConfig.getTrain().setValidSetRate(0.1);

    // One hidden layer of five tanh nodes.
    List<Integer> hiddenNodes = new ArrayList<Integer>();
    hiddenNodes.add(5);
    List<String> activations = new ArrayList<String>();
    activations.add("tanh");

    xorConfig.getTrain().getParams().put("Propagation", "Q");
    xorConfig.getTrain().getParams().put("NumHiddenLayers", 1);
    xorConfig.getTrain().getParams().put("LearningRate", 1);
    xorConfig.getTrain().getParams().put("NumHiddenNodes", hiddenNodes);
    xorConfig.getTrain().getParams().put("ActivationFunc", activations);
    xorConfig.getTrain().setNumTrainEpochs(100);

    NNTrainer nnTrainer = new NNTrainer(xorConfig, 0, false);
    nnTrainer.setTrainSet(xor_Trainset);
    nnTrainer.setValidSet(xor_Validset);
    nnTrainer.train();
    BasicNetwork network = nnTrainer.getNetwork();

    // Expected outcome of "score * 1000 < 500" for each validation pair, in iteration order.
    boolean[] expectedBelowHalf = { true, false, false, true };
    int caseIndex = 0;
    for (MLDataPair pair : xor_Validset) {
        double firstOutput = network.compute(pair.getInput()).getData()[0];
        boolean belowHalf = firstOutput * 1000 < 500;
        Assert.assertEquals(belowHalf, expectedBelowHalf[caseIndex++]);
    }

    int expectedLayerCount = (Integer) (xorConfig.getTrain().getParams().get("NumHiddenLayers")) + 2;
    Assert.assertEquals(network.getLayerCount(), expectedLayerCount);
}
Also used : BasicMLDataPair(org.encog.ml.data.basic.BasicMLDataPair) MLDataPair(org.encog.ml.data.MLDataPair) ModelConfig(ml.shifu.shifu.container.obj.ModelConfig) BasicNetwork(org.encog.neural.networks.BasicNetwork) NNTrainer(ml.shifu.shifu.core.alg.NNTrainer) ArrayList(java.util.ArrayList)

Example 4 with NNTrainer

use of ml.shifu.shifu.core.alg.NNTrainer in project shifu by ShifuML.

the class AbstractTrainerTest method testLoad1.

/**
 * Feeds 1000 randomly generated records to an {@code NNTrainer} under three training
 * configurations and checks each time that the resulting training set size stays within 5% of
 * (1 - validSetRate) * baggingSampleRate * total record count.
 */
@Test
public void testLoad1() throws IOException {
    ModelConfig modelConfig = CommonUtils.loadModelConfig("src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ModelConfig.json", SourceType.LOCAL);
    MLDataSet set = new BasicMLDataSet();
    double[] input = new double[modelConfig.getVarSelectFilterNum()];
    for (int row = 0; row < 1000; row++) {
        for (int col = 0; col < input.length; col++) {
            input[col] = random.nextDouble();
        }
        double[] ideal = { random.nextInt(2) };
        MLDataPair sample = new BasicMLDataPair(new BasicMLData(input), new BasicMLData(ideal));
        set.add(sample);
    }

    // Configuration 1: in-memory training.
    modelConfig.getTrain().setTrainOnDisk(false);
    assertTrainSetSizeWithinTolerance(modelConfig, set);

    // Configuration 2: fixed initial input enabled.
    modelConfig.getTrain().setFixInitInput(true);
    assertTrainSetSizeWithinTolerance(modelConfig, set);

    // Configuration 3: fixed initial input disabled, bagging without replacement.
    modelConfig.getTrain().setFixInitInput(false);
    modelConfig.getTrain().setBaggingWithReplacement(false);
    assertTrainSetSizeWithinTolerance(modelConfig, set);
}

/**
 * Builds a fresh {@code NNTrainer} for the given configuration, loads the data set into it and
 * asserts that the training set record count lies between 95% and 105% of the expected size.
 */
private void assertTrainSetSizeWithinTolerance(ModelConfig modelConfig, MLDataSet set) throws IOException {
    AbstractTrainer trainer = new NNTrainer(modelConfig, 0, false);
    trainer.setDataSet(set);
    double expectedSize = (1 - modelConfig.getValidSetRate()) * modelConfig.getBaggingSampleRate() * set.getRecordCount();
    Assert.assertTrue(trainer.getTrainSet().getRecordCount() <= expectedSize * 1.05);
    Assert.assertTrue(trainer.getTrainSet().getRecordCount() >= expectedSize * 0.95);
}
Also used : BasicMLDataPair(org.encog.ml.data.basic.BasicMLDataPair) MLDataPair(org.encog.ml.data.MLDataPair) ModelConfig(ml.shifu.shifu.container.obj.ModelConfig) NNTrainer(ml.shifu.shifu.core.alg.NNTrainer) MLDataSet(org.encog.ml.data.MLDataSet) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) BasicMLDataPair(org.encog.ml.data.basic.BasicMLDataPair) BasicMLData(org.encog.ml.data.basic.BasicMLData) BasicMLDataSet(org.encog.ml.data.basic.BasicMLDataSet) Test(org.testng.annotations.Test)

Example 5 with NNTrainer

use of ml.shifu.shifu.core.alg.NNTrainer in project shifu by ShifuML.

the class TrainModelActorTest method testActor.

/**
 * Runs the normalize actor over the example data set, then trains five models through the train
 * actor and verifies that one model file per trainer ({@code model0.nn} .. {@code model4.nn})
 * was written to {@code ./models}. Temporary directories are deleted at the end of the test.
 *
 * @throws IOException
 *             if a directory or data file cannot be created, read or deleted
 * @throws InterruptedException
 *             if the thread is interrupted while waiting for an actor to terminate
 */
@Test
public void testActor() throws IOException, InterruptedException {
    File tmpDir = new File("./tmp");
    FileUtils.forceMkdir(tmpDir);
    // create normalize data
    actorSystem = ActorSystem.create("shifuActorSystem");
    ActorRef normalizeRef = actorSystem.actorOf(new Props(new UntypedActorFactory() {

        private static final long serialVersionUID = 6777309320338075269L;

        public UntypedActor create() throws IOException {
            return new NormalizeDataActor(modelConfig, columnConfigList, new AkkaExecStatus(true));
        }
    }), "normalize-calculator");
    List<Scanner> scanners = ShifuFileUtils.getDataScanners("src/test/resources/example/cancer-judgement/DataStore/DataSet1", SourceType.LOCAL);
    normalizeRef.tell(new AkkaActorInputMessage(scanners), normalizeRef);
    // poll until the normalize actor has terminated
    while (!normalizeRef.isTerminated()) {
        Thread.sleep(5000);
    }
    // release the raw-data scanners before the variable is reused for the normalized data
    if (scanners != null) {
        for (Scanner scanner : scanners) {
            scanner.close();
        }
    }
    File outputFile = new File("./tmp/NormalizedData");
    Assert.assertTrue(outputFile.exists());
    // start to run trainer
    actorSystem = ActorSystem.create("shifuActorSystem");
    File models = new File("./models");
    FileUtils.forceMkdir(models);
    // number of trainers created below; also the number of model files expected afterwards
    final int numTrainers = 5;
    final List<AbstractTrainer> trainers = new ArrayList<AbstractTrainer>();
    for (int i = 0; i < numTrainers; i++) {
        AbstractTrainer trainer;
        if (modelConfig.getAlgorithm().equalsIgnoreCase("NN")) {
            trainer = new NNTrainer(this.modelConfig, i, false);
        } else if (modelConfig.getAlgorithm().equalsIgnoreCase("SVM")) {
            trainer = new SVMTrainer(this.modelConfig, i, false);
        } else if (modelConfig.getAlgorithm().equalsIgnoreCase("LR")) {
            trainer = new LogisticRegressionTrainer(this.modelConfig, i, false);
        } else {
            throw new RuntimeException("unsupport algorithm");
        }
        trainers.add(trainer);
    }
    // train model
    ActorRef modelTrainRef = actorSystem.actorOf(new Props(new UntypedActorFactory() {

        private static final long serialVersionUID = 6777309320338075269L;

        public UntypedActor create() throws IOException {
            return new TrainModelActor(modelConfig, columnConfigList, new AkkaExecStatus(true), trainers);
        }
    }), "trainer");
    scanners = ShifuFileUtils.getDataScanners("./tmp/NormalizedData", SourceType.LOCAL);
    modelTrainRef.tell(new AkkaActorInputMessage(scanners), modelTrainRef);
    // poll until the train actor has terminated
    while (!modelTrainRef.isTerminated()) {
        Thread.sleep(5000);
    }
    for (Scanner scanner : scanners) {
        scanner.close();
    }
    // every trainer index must have produced its own model file, not just model0
    for (int i = 0; i < numTrainers; i++) {
        File modelFile = new File("./models/model" + i + ".nn");
        Assert.assertTrue(modelFile.exists());
    }
    File modelsTemp = new File("./modelsTmp");
    FileUtils.deleteDirectory(modelsTemp);
    FileUtils.deleteDirectory(models);
    FileUtils.deleteDirectory(tmpDir);
}
Also used : Scanner(java.util.Scanner) AkkaActorInputMessage(ml.shifu.shifu.message.AkkaActorInputMessage) NNTrainer(ml.shifu.shifu.core.alg.NNTrainer) ArrayList(java.util.ArrayList) SVMTrainer(ml.shifu.shifu.core.alg.SVMTrainer) AbstractTrainer(ml.shifu.shifu.core.AbstractTrainer) LogisticRegressionTrainer(ml.shifu.shifu.core.alg.LogisticRegressionTrainer) File(java.io.File) Test(org.testng.annotations.Test)

Aggregations

NNTrainer (ml.shifu.shifu.core.alg.NNTrainer)7 ArrayList (java.util.ArrayList)4 ModelConfig (ml.shifu.shifu.container.obj.ModelConfig)3 SVMTrainer (ml.shifu.shifu.core.alg.SVMTrainer)3 MLDataPair (org.encog.ml.data.MLDataPair)3 BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair)3 Test (org.testng.annotations.Test)3 IOException (java.io.IOException)2 AbstractTrainer (ml.shifu.shifu.core.AbstractTrainer)2 LogisticRegressionTrainer (ml.shifu.shifu.core.alg.LogisticRegressionTrainer)2 MLDataSet (org.encog.ml.data.MLDataSet)2 BasicMLData (org.encog.ml.data.basic.BasicMLData)2 BasicMLDataSet (org.encog.ml.data.basic.BasicMLDataSet)2 File (java.io.File)1 Scanner (java.util.Scanner)1 ShifuException (ml.shifu.shifu.exception.ShifuException)1 AkkaActorInputMessage (ml.shifu.shifu.message.AkkaActorInputMessage)1 BasicNetwork (org.encog.neural.networks.BasicNetwork)1 BeforeClass (org.testng.annotations.BeforeClass)1