Search in sources:

Example 31 with ModelConfig

Use of ml.shifu.shifu.container.obj.ModelConfig in project shifu by ShifuML.

Method testExceptionWhileSetupModel of class NNTrainerTest.

/**
 * Checks that building the network fails with a {@link RuntimeException} for an inconsistent
 * NN configuration.
 *
 * <p>The configuration declares 2 hidden layers but supplies 3 hidden-node counts and a single
 * activation function. Presumably this mismatch is what makes {@code buildNetwork()} throw
 * (TODO confirm against NNTrainer).
 *
 * <p>Loading the data set is deliberately NOT wrapped in a try/catch. An {@link IOException}
 * there is a setup failure and must fail the test. Swallowing it could let a later
 * {@code RuntimeException} (e.g. caused by the missing data) satisfy {@code expectedExceptions}
 * for the wrong reason.
 *
 * @throws IOException if the training data set cannot be set on the trainer
 */
@Test(expectedExceptions = RuntimeException.class)
public void testExceptionWhileSetupModel() throws IOException {
    ModelConfig config = ModelConfig.createInitModelConfig(".", ALGORITHM.NN, ".", false);
    config.getTrain().getParams().put("Propagation", "Q");
    config.getTrain().getParams().put("NumHiddenLayers", 2);
    config.getTrain().getParams().put("LearningRate", 0.1);

    // Three node counts for two declared hidden layers: intentionally inconsistent.
    List<Integer> nodes = new ArrayList<Integer>();
    nodes.add(3);
    nodes.add(3);
    nodes.add(3);

    // Only one activation function for two declared hidden layers: intentionally inconsistent.
    List<String> func = new ArrayList<String>();
    func.add("tanh");

    config.getTrain().getParams().put("NumHiddenNodes", nodes);
    config.getTrain().getParams().put("ActivationFunc", func);
    config.getTrain().setNumTrainEpochs(50);

    NNTrainer trainer = new NNTrainer(config, 0, false);
    // Let any IOException propagate: it is a setup failure, not the exception under test.
    trainer.setDataSet(xor_Trainset);
    trainer.buildNetwork();
}
Also used : ModelConfig(ml.shifu.shifu.container.obj.ModelConfig) NNTrainer(ml.shifu.shifu.core.alg.NNTrainer) ArrayList(java.util.ArrayList) IOException(java.io.IOException) Test(org.testng.annotations.Test)

Example 32 with ModelConfig

Use of ml.shifu.shifu.container.obj.ModelConfig in project shifu by ShifuML.

Method testObjectSeri of class EqualPopulationBinningTest.

/**
 * Round-trips an {@link EqualPopulationBinning} through its string form and checks that the
 * binning rebuilt by {@code AbstractBinning.constructBinningFromStr} reports the same data bins.
 *
 * <p>The input data come from a fixed-seed {@link Random}. The previous time-based seed made
 * any failure impossible to reproduce.
 */
@Test
public void testObjectSeri() {
    // Fixed seed so that a failing run can be reproduced exactly.
    final long seed = 42L;
    Random rd = new Random(seed);

    EqualPopulationBinning binning = new EqualPopulationBinning(10);
    for (int i = 0; i < 10000; i++) {
        binning.addData(Double.toString(rd.nextGaussian()));
    }
    String binningStr = binning.objToString();
    String originalBinningData = binning.getDataBin().toString();

    ModelConfig modelConfig = new ModelConfig();
    modelConfig.getStats().setBinningMethod(BinningMethod.EqualPositive);
    ColumnConfig columnConfig = new ColumnConfig();
    columnConfig.setColumnType(ColumnType.N);

    AbstractBinning<?> otherBinning = AbstractBinning.constructBinningFromStr(modelConfig, columnConfig, binningStr);
    String newBinningData = otherBinning.getDataBin().toString();

    // TestNG's argument order is (actual, expected).
    Assert.assertEquals(newBinningData, originalBinningData);
}
Also used : ModelConfig(ml.shifu.shifu.container.obj.ModelConfig) Random(java.util.Random) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) Test(org.testng.annotations.Test)

Example 33 with ModelConfig

Use of ml.shifu.shifu.container.obj.ModelConfig in project shifu by ShifuML.

Method loadModelConfig of class CommonUtils.

/**
 * Load model configuration from the path and the source type.
 *
 * <p>If the training section names a grid-search config file, that file's raw content is also
 * loaded and stored on the train config. Grid search params parsing is done later, in
 * {@link GridSearch} initialization. For {@link SourceType#HDFS}, the grid-search config file is
 * expected to have been uploaded to the model-set path. It is therefore looked up there by its
 * base file name only.
 *
 * @param path
 *            model file path
 * @param sourceType
 *            source type of model file
 * @return model config instance
 * @throws IOException
 *             if any IO exception in parsing json, or in reading the grid-search config file.
 *
 * @throws IllegalArgumentException
 *             if {@code path} is null or empty, if sourceType is null.
 */
public static ModelConfig loadModelConfig(String path, SourceType sourceType) throws IOException {
    ModelConfig modelConfig = loadJSON(path, sourceType, ModelConfig.class);
    String gridConfigFile = modelConfig.getTrain().getGridConfigFile();
    if (StringUtils.isNotBlank(gridConfigFile)) {
        String gridConfigPath = gridConfigFile.trim();
        if (sourceType == SourceType.HDFS) {
            // The grid-search config file is uploaded to the model-set path, so keep only its base
            // name. Accept both '/' and the platform separator: on Windows a user may write either,
            // and File.separator alone would leave the directory part of a '/'-style path intact.
            int lastSeparator = Math.max(gridConfigPath.lastIndexOf('/'),
                    gridConfigPath.lastIndexOf(File.separatorChar));
            String gridConfigFileName = gridConfigPath.substring(lastSeparator + 1);
            gridConfigPath = new PathFinder(modelConfig).getPathBySourceType(gridConfigFileName,
                    SourceType.HDFS);
        }
        // Only load file content. Grid search params parsing is done in {@link GridSearch} initialization.
        modelConfig.getTrain().setGridConfigFileContent(loadFileContent(gridConfigPath, sourceType));
    }
    return modelConfig;
}
Also used : ModelConfig(ml.shifu.shifu.container.obj.ModelConfig) PathFinder(ml.shifu.shifu.fs.PathFinder)

Aggregations

ModelConfig (ml.shifu.shifu.container.obj.ModelConfig)33 Test (org.testng.annotations.Test)18 ArrayList (java.util.ArrayList)9 File (java.io.File)8 ValidateResult (ml.shifu.shifu.container.meta.ValidateResult)8 ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)8 TrainingDataSet (ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet)4 IOException (java.io.IOException)3 EvalConfig (ml.shifu.shifu.container.obj.EvalConfig)3 NNTrainer (ml.shifu.shifu.core.alg.NNTrainer)3 BeforeClass (org.testng.annotations.BeforeClass)3 FileInputStream (java.io.FileInputStream)2 InvocationTargetException (java.lang.reflect.InvocationTargetException)2 HashSet (java.util.HashSet)2 List (java.util.List)2 Properties (java.util.Properties)2 RawSourceData (ml.shifu.shifu.container.obj.RawSourceData)2 PathFinder (ml.shifu.shifu.fs.PathFinder)2 MLDataPair (org.encog.ml.data.MLDataPair)2 BasicMLDataPair (org.encog.ml.data.basic.BasicMLDataPair)2