Search in sources:

Example 56 with ColumnConfig

Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

The class ValidationConductorTest, method testPartershipModel.

// @Test
/**
 * Loads a model configuration, a column configuration list and one data part file from a
 * developer-local directory. It then validates each candidate column on its own and prints the
 * resulting validation error.
 *
 * <p>The {@code @Test} annotation is commented out, presumably because the fixtures live under an
 * absolute path on one developer's machine. Enable it only where that data is available.
 *
 * @throws IOException if any of the configuration or data files cannot be read
 */
public void testPartershipModel() throws IOException {
    String dataDir = "/Users/zhanhu/temp/partnership_varselect/";
    ModelConfig modelConfig = CommonUtils.loadModelConfig(dataDir + "ModelConfig.json", RawSourceData.SourceType.LOCAL);
    List<ColumnConfig> columnConfigList = CommonUtils.loadColumnConfigList(dataDir + "ColumnConfig.json", RawSourceData.SourceType.LOCAL);

    // Keep only the columns CommonUtils considers good candidates.
    List<Integer> columnIdList = new ArrayList<Integer>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig columnConfig : columnConfigList) {
        if (CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
            columnIdList.add(columnConfig.getColumnNum());
        }
    }
    TrainingDataSet trainingDataSet = new TrainingDataSet(columnIdList);

    // IOUtils.readLines does not close the stream it is given, so close it explicitly.
    List<String> recordsList;
    FileInputStream dataStream = new FileInputStream(dataDir + "part-m-00479");
    try {
        recordsList = IOUtils.readLines(dataStream);
    } finally {
        dataStream.close();
    }
    for (String record : recordsList) {
        addNormalizedRecordIntoTrainDataSet(modelConfig, columnConfigList, trainingDataSet, record);
    }

    // Reuse one working set, holding exactly one column per validation run.
    Set<Integer> workingList = new HashSet<Integer>();
    for (Integer columnId : trainingDataSet.getDataColumnIdList()) {
        workingList.clear();
        workingList.add(columnId);
        ValidationConductor conductor = new ValidationConductor(modelConfig, columnConfigList, workingList, trainingDataSet);
        double error = conductor.runValidate();
        System.out.println("The error is - " + error + ", for columnId - " + columnId);
    }
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList) FileInputStream(java.io.FileInputStream) ModelConfig(ml.shifu.shifu.container.obj.ModelConfig) TrainingDataSet(ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet) HashSet(java.util.HashSet)

Example 57 with ColumnConfig

Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

The class WrapperWorkerConductorTest, method genTrainingDataSet.

/**
 * Builds a {@link TrainingDataSet} over the candidate columns of {@code columnConfigList}. The set is
 * filled with every line of the cancer-judgement sample data file under the test resources.
 *
 * @param modelConfig model configuration handed to {@code addRecordIntoTrainDataSet} for each record
 * @param columnConfigList column configurations; columns whose {@code isCandidate(hasCandidates)} is
 *        true are included in the data set
 * @return the populated training data set
 * @throws IOException if the sample data file cannot be opened or read
 */
public TrainingDataSet genTrainingDataSet(ModelConfig modelConfig, List<ColumnConfig> columnConfigList) throws IOException {
    List<Integer> columnIdList = new ArrayList<Integer>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig columnConfig : columnConfigList) {
        if (columnConfig.isCandidate(hasCandidates)) {
            columnIdList.add(columnConfig.getColumnNum());
        }
    }
    TrainingDataSet trainingDataSet = new TrainingDataSet(columnIdList);

    // IOUtils.readLines does not close the stream it is given, so close it explicitly.
    List<String> recordsList;
    FileInputStream dataStream = new FileInputStream("src/test/resources/example/cancer-judgement/DataStore/DataSet1/part-00");
    try {
        recordsList = IOUtils.readLines(dataStream);
    } finally {
        dataStream.close();
    }
    for (String record : recordsList) {
        addRecordIntoTrainDataSet(modelConfig, columnConfigList, trainingDataSet, record);
    }
    return trainingDataSet;
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList) TrainingDataSet(ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet) FileInputStream(java.io.FileInputStream)

Example 58 with ColumnConfig

Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

The class NormalizerTest, method getZScore1.

/**
 * Checks normalization of a numeric column with mean 2.0 and standard deviation 1.0. The mean itself
 * is expected to normalize to 0.0, and so is an unparsable value.
 */
@Test
public void getZScore1() {
    ColumnConfig config = new ColumnConfig();
    config.setMean(2.0);
    config.setStdDev(1.0);
    config.setColumnType(ColumnType.N);
    // TestNG's Assert.assertEquals takes (actual, expected); keep that order so failure messages are correct.
    Assert.assertEquals(Normalizer.normalize(config, "2", 6.0).get(0), 0.0);
    Assert.assertEquals(Normalizer.normalize(config, "ABC", 0.1).get(0), 0.0);
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) Test(org.testng.annotations.Test)

Example 59 with ColumnConfig

Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

The class NormalizerTest, method numericalNormalizeTest.

@Test
public void numericalNormalizeTest() {
    // Numeric column: mean 2.0, standard deviation 1.0, so the input "5.0" is expected to map to 3.0.
    ColumnConfig columnConfig = new ColumnConfig();
    columnConfig.setMean(2.0);
    columnConfig.setStdDev(1.0);
    columnConfig.setColumnType(ColumnType.N);

    // Bin boundaries are -inf, 2.0, 4.0, 6.0; each WOE list holds one entry more than there are boundaries.
    // The assertions below expect "3.0" to pick the second entries and bad/missing input to pick the last.
    ColumnBinning binning = new ColumnBinning();
    binning.setBinCountWoe(Arrays.asList(10.0, 11.0, 12.0, 13.0, 6.5));
    binning.setBinWeightedWoe(Arrays.asList(20.0, 21.0, 22.0, 23.0, 16.5));
    binning.setBinBoundary(Arrays.asList(Double.NEGATIVE_INFINITY, 2.0, 4.0, 6.0));
    binning.setBinCountNeg(Arrays.asList(1, 2, 3, 4, 5));
    binning.setBinCountPos(Arrays.asList(5, 4, 3, 2, 1));
    columnConfig.setColumnBinning(binning);

    // Current and legacy z-score: a parsable value gives its z-score with or without a cutoff,
    // while an unparsable or null value gives 0.0.
    for (NormType zscoreType : new NormType[] { NormType.ZSCALE, NormType.OLD_ZSCALE }) {
        Assert.assertEquals(Normalizer.normalize(columnConfig, "5.0", 4.0, zscoreType).get(0), 3.0);
        Assert.assertEquals(Normalizer.normalize(columnConfig, "5.0", null, zscoreType).get(0), 3.0);
        Assert.assertEquals(Normalizer.normalize(columnConfig, "wrong_format", 4.0, zscoreType).get(0), 0.0);
        Assert.assertEquals(Normalizer.normalize(columnConfig, null, 4.0, zscoreType).get(0), 0.0);
    }

    // Weighted and count-based WOE of the matching bin; unparsable or null input falls back to the last entry.
    Assert.assertEquals(Normalizer.normalize(columnConfig, "3.0", null, NormType.WEIGHT_WOE).get(0), 21.0);
    Assert.assertEquals(Normalizer.normalize(columnConfig, "wrong_format", null, NormType.WEIGHT_WOE).get(0), 16.5);
    Assert.assertEquals(Normalizer.normalize(columnConfig, null, null, NormType.WEIGHT_WOE).get(0), 16.5);
    Assert.assertEquals(Normalizer.normalize(columnConfig, "3.0", null, NormType.WOE).get(0), 11.0);
    Assert.assertEquals(Normalizer.normalize(columnConfig, "wrong_format", null, NormType.WOE).get(0), 6.5);
    Assert.assertEquals(Normalizer.normalize(columnConfig, null, null, NormType.WOE).get(0), 6.5);

    // Plain and weighted hybrid currently behave like z-score for numerical values.
    for (NormType hybridType : new NormType[] { NormType.HYBRID, NormType.WEIGHT_HYBRID }) {
        Assert.assertEquals(Normalizer.normalize(columnConfig, "5.0", 4.0, hybridType).get(0), 3.0);
        Assert.assertEquals(Normalizer.normalize(columnConfig, "5.0", null, hybridType).get(0), 3.0);
        Assert.assertEquals(Normalizer.normalize(columnConfig, "wrong_format", 4.0, hybridType).get(0), 0.0);
        Assert.assertEquals(Normalizer.normalize(columnConfig, null, 4.0, hybridType).get(0), 0.0);
    }

    // Disabled WOE z-score checks, kept for reference:
    // Assert.assertEquals(Normalizer.normalize(config, "3.0", 10.0, NormType.WOE_ZSCORE), 0.2);
    // Assert.assertEquals(Normalizer.normalize(config, "wrong_format", 12.0, NormType.WOE_ZSCORE), -1.6);
    // Assert.assertEquals(Normalizer.normalize(config, null, 12.0, NormType.WOE_ZSCORE), -1.6);
    //
    // Assert.assertEquals(Normalizer.normalize(config, "3.0", 20.0, NormType.WEIGHT_WOE_ZSCORE), 0.2);
    // Assert.assertEquals(Normalizer.normalize(config, "wrong_format", 22.0, NormType.WEIGHT_WOE_ZSCORE), -1.6);
    // Assert.assertEquals(Normalizer.normalize(config, null, 22.0, NormType.WEIGHT_WOE_ZSCORE), -1.6);
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ColumnBinning(ml.shifu.shifu.container.obj.ColumnBinning) Test(org.testng.annotations.Test)

Example 60 with ColumnConfig

Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

The class NormalizerTest, method getZScore3.

/**
 * Checks normalization of a categorical column with mean 2.0 and standard deviation 1.0. Category "2"
 * is expected to normalize to 0.0.
 */
@Test
public void getZScore3() {
    ColumnConfig config = new ColumnConfig();
    config.setColumnType(ColumnType.C);
    config.setMean(2.0);
    config.setStdDev(1.0);
    config.setBinCategory(Arrays.asList(new String[] { "1", "2", "3", "4", "ABC" }));
    // NOTE(review): five categories but only four positive-case rates, so "ABC" has no rate. Confirm this
    // is intentional. Category "2" lines up with rate 2.0, presumably the value that is z-scored.
    config.setBinPosCaseRate(Arrays.asList(new Double[] { 0.1, 2.0, 0.3, 0.1 }));
    // TestNG's Assert.assertEquals takes (actual, expected).
    Assert.assertEquals(Normalizer.normalize(config, "2", 0.1).get(0), 0.0);
    // Disabled check for a category that is not in the bin list:
    // Assert.assertEquals(Normalizer.normalize(config, "5", 0.1).get(0), 0.0);
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) Test(org.testng.annotations.Test)

Aggregations

ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)131 ArrayList (java.util.ArrayList)36 Test (org.testng.annotations.Test)17 IOException (java.io.IOException)16 HashMap (java.util.HashMap)12 Tuple (org.apache.pig.data.Tuple)10 File (java.io.File)8 NSColumn (ml.shifu.shifu.column.NSColumn)8 ModelConfig (ml.shifu.shifu.container.obj.ModelConfig)8 ShifuException (ml.shifu.shifu.exception.ShifuException)8 Path (org.apache.hadoop.fs.Path)8 List (java.util.List)7 Scanner (java.util.Scanner)7 DataBag (org.apache.pig.data.DataBag)7 SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType)5 BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)5 TrainingDataSet (ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet)5 BasicMLData (org.encog.ml.data.basic.BasicMLData)5 BufferedWriter (java.io.BufferedWriter)3 FileInputStream (java.io.FileInputStream)3