Example use of ml.shifu.shifu.container.obj.ColumnConfig in the shifu project by ShifuML.
Taken from the class ValidationConductorTest, method testPartershipModel.
/**
 * Manual experiment; the {@code @Test} annotation below is commented out.
 *
 * Loads a model config and column config list from a developer-local directory, builds a
 * {@link TrainingDataSet} over the columns accepted by {@code CommonUtils.isGoodCandidate},
 * feeds every line of one data part file into it, and then runs a {@link ValidationConductor}
 * once per data column, printing the resulting error for each column id.
 *
 * @throws IOException if the configs or the data file cannot be opened or read
 */
// @Test
public void testPartershipModel() throws IOException {
    ModelConfig modelConfig = CommonUtils.loadModelConfig(
            "/Users/zhanhu/temp/partnership_varselect/ModelConfig.json", RawSourceData.SourceType.LOCAL);
    List<ColumnConfig> columnConfigList = CommonUtils.loadColumnConfigList(
            "/Users/zhanhu/temp/partnership_varselect/ColumnConfig.json", RawSourceData.SourceType.LOCAL);

    // Collect the column numbers of every column that qualifies as a good candidate.
    List<Integer> columnIdList = new ArrayList<Integer>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig columnConfig : columnConfigList) {
        if (CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
            columnIdList.add(columnConfig.getColumnNum());
        }
    }

    TrainingDataSet trainingDataSet = new TrainingDataSet(columnIdList);

    // IOUtils.readLines does not close the stream it is given, so close it here explicitly
    // (try/finally keeps the code compilable on pre-Java-7 source levels).
    List<String> recordsList;
    FileInputStream recordsStream = new FileInputStream("/Users/zhanhu/temp/partnership_varselect/part-m-00479");
    try {
        recordsList = IOUtils.readLines(recordsStream);
    } finally {
        recordsStream.close();
    }

    for (String record : recordsList) {
        addNormalizedRecordIntoTrainDataSet(modelConfig, columnConfigList, trainingDataSet, record);
    }

    // Validate each data column on its own: the working set is reused and holds exactly one id per run.
    Set<Integer> workingList = new HashSet<Integer>();
    for (Integer columnId : trainingDataSet.getDataColumnIdList()) {
        workingList.clear();
        workingList.add(columnId);
        ValidationConductor conductor = new ValidationConductor(modelConfig, columnConfigList, workingList, trainingDataSet);
        double error = conductor.runValidate();
        System.out.println("The error is - " + error + ", for columnId - " + columnId);
    }
}
Example use of ml.shifu.shifu.container.obj.ColumnConfig in the shifu project by ShifuML.
Taken from the class WrapperWorkerConductorTest, method genTrainingDataSet.
/**
 * Builds a {@link TrainingDataSet} from the cancer-judgement example data shipped with the tests.
 *
 * The data set is created over the column numbers of all columns for which
 * {@code columnConfig.isCandidate(hasCandidates)} is true, and every line of the example
 * part file is added to it through {@code addRecordIntoTrainDataSet}.
 *
 * @param modelConfig      model configuration, passed through to {@code addRecordIntoTrainDataSet}
 * @param columnConfigList column configurations used to pick candidate columns and to add records
 * @return the populated training data set
 * @throws IOException if the example data file cannot be opened or read
 */
public TrainingDataSet genTrainingDataSet(ModelConfig modelConfig, List<ColumnConfig> columnConfigList) throws IOException {
    List<Integer> columnIdList = new ArrayList<Integer>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig columnConfig : columnConfigList) {
        if (columnConfig.isCandidate(hasCandidates)) {
            columnIdList.add(columnConfig.getColumnNum());
        }
    }

    TrainingDataSet trainingDataSet = new TrainingDataSet(columnIdList);

    // IOUtils.readLines does not close the stream it is given, so close it here explicitly
    // (try/finally keeps the code compilable on pre-Java-7 source levels).
    List<String> recordsList;
    FileInputStream recordsStream = new FileInputStream("src/test/resources/example/cancer-judgement/DataStore/DataSet1/part-00");
    try {
        recordsList = IOUtils.readLines(recordsStream);
    } finally {
        recordsStream.close();
    }

    for (String record : recordsList) {
        addRecordIntoTrainDataSet(modelConfig, columnConfigList, trainingDataSet, record);
    }
    return trainingDataSet;
}
Example use of ml.shifu.shifu.container.obj.ColumnConfig in the shifu project by ShifuML.
Taken from the class NormalizerTest, method getZScore1.
/**
 * Checks the three-argument {@code Normalizer.normalize} on a numerical column:
 * both a value equal to the configured mean and a non-numeric value are expected to yield 0.0.
 */
@Test
public void getZScore1() {
    // Numerical column configured with mean 2.0 and standard deviation 1.0.
    ColumnConfig numericColumn = new ColumnConfig();
    numericColumn.setMean(2.0);
    numericColumn.setStdDev(1.0);
    numericColumn.setColumnType(ColumnType.N);

    // Raw value "2" matches the mean, so the first normalized entry should be 0.0.
    Assert.assertEquals(0.0, Normalizer.normalize(numericColumn, "2", 6.0).get(0));

    // Raw value "ABC" is not a number; the first normalized entry is still expected to be 0.0.
    Assert.assertEquals(0.0, Normalizer.normalize(numericColumn, "ABC", 0.1).get(0));
}
Example use of ml.shifu.shifu.container.obj.ColumnConfig in the shifu project by ShifuML.
Taken from the class NormalizerTest, method numericalNormalizeTest.
/**
 * Exercises {@code Normalizer.normalize} on a numerical column for each supported
 * {@code NormType}: ZSCALE, OLD_ZSCALE, WEIGHT_WOE, WOE, HYBRID and WEIGHT_HYBRID.
 * For every type it checks a valid value, and an unparsable or null raw value.
 */
@Test
public void numericalNormalizeTest() {
    // Fixture: numerical column with mean 2.0, standard deviation 1.0 and binning statistics.
    ColumnConfig numericColumn = new ColumnConfig();
    numericColumn.setMean(2.0);
    numericColumn.setStdDev(1.0);
    numericColumn.setColumnType(ColumnType.N);

    ColumnBinning binning = new ColumnBinning();
    binning.setBinCountWoe(Arrays.asList(10.0, 11.0, 12.0, 13.0, 6.5));
    binning.setBinWeightedWoe(Arrays.asList(20.0, 21.0, 22.0, 23.0, 16.5));
    binning.setBinBoundary(Arrays.asList(Double.NEGATIVE_INFINITY, 2.0, 4.0, 6.0));
    binning.setBinCountNeg(Arrays.asList(1, 2, 3, 4, 5));
    binning.setBinCountPos(Arrays.asList(5, 4, 3, 2, 1));
    numericColumn.setColumnBinning(binning);

    // Z-score style normalization, first the current and then the old variant.
    // "5.0" is expected to map to 3.0 (consistent with (5.0 - 2.0) / 1.0), with or without
    // the third argument; unparsable and null raw values are expected to map to 0.0.
    NormType[] zscaleTypes = { NormType.ZSCALE, NormType.OLD_ZSCALE };
    for (NormType zscaleType : zscaleTypes) {
        Assert.assertEquals(Normalizer.normalize(numericColumn, "5.0", 4.0, zscaleType).get(0), 3.0);
        Assert.assertEquals(Normalizer.normalize(numericColumn, "5.0", null, zscaleType).get(0), 3.0);
        Assert.assertEquals(Normalizer.normalize(numericColumn, "wrong_format", 4.0, zscaleType).get(0), 0.0);
        Assert.assertEquals(Normalizer.normalize(numericColumn, null, 4.0, zscaleType).get(0), 0.0);
    }

    // Weighted WOE: "3.0" is expected to give 21.0 (second weighted WOE entry), while
    // unparsable and null raw values are expected to give 16.5 (last weighted WOE entry).
    Assert.assertEquals(Normalizer.normalize(numericColumn, "3.0", null, NormType.WEIGHT_WOE).get(0), 21.0);
    Assert.assertEquals(Normalizer.normalize(numericColumn, "wrong_format", null, NormType.WEIGHT_WOE).get(0), 16.5);
    Assert.assertEquals(Normalizer.normalize(numericColumn, null, null, NormType.WEIGHT_WOE).get(0), 16.5);

    // Count-based WOE: same shape of expectations against the count WOE list (11.0 and 6.5).
    Assert.assertEquals(Normalizer.normalize(numericColumn, "3.0", null, NormType.WOE).get(0), 11.0);
    Assert.assertEquals(Normalizer.normalize(numericColumn, "wrong_format", null, NormType.WOE).get(0), 6.5);
    Assert.assertEquals(Normalizer.normalize(numericColumn, null, null, NormType.WOE).get(0), 6.5);

    // Hybrid normalization uses z-score for numerical columns. Currently WEIGHT_HYBRID and
    // HYBRID act the same for numerical values, so both share the z-score expectations.
    NormType[] hybridTypes = { NormType.HYBRID, NormType.WEIGHT_HYBRID };
    for (NormType hybridType : hybridTypes) {
        Assert.assertEquals(Normalizer.normalize(numericColumn, "5.0", 4.0, hybridType).get(0), 3.0);
        Assert.assertEquals(Normalizer.normalize(numericColumn, "5.0", null, hybridType).get(0), 3.0);
        Assert.assertEquals(Normalizer.normalize(numericColumn, "wrong_format", 4.0, hybridType).get(0), 0.0);
        Assert.assertEquals(Normalizer.normalize(numericColumn, null, 4.0, hybridType).get(0), 0.0);
    }

    // WOE z-score normalization checks are currently disabled:
    // Assert.assertEquals(Normalizer.normalize(config, "3.0", 10.0, NormType.WOE_ZSCORE), 0.2);
    // Assert.assertEquals(Normalizer.normalize(config, "wrong_format", 12.0, NormType.WOE_ZSCORE), -1.6);
    // Assert.assertEquals(Normalizer.normalize(config, null, 12.0, NormType.WOE_ZSCORE), -1.6);
    //
    // Assert.assertEquals(Normalizer.normalize(config, "3.0", 20.0, NormType.WEIGHT_WOE_ZSCORE), 0.2);
    // Assert.assertEquals(Normalizer.normalize(config, "wrong_format", 22.0, NormType.WEIGHT_WOE_ZSCORE), -1.6);
    // Assert.assertEquals(Normalizer.normalize(config, null, 22.0, NormType.WEIGHT_WOE_ZSCORE), -1.6);
}
Example use of ml.shifu.shifu.container.obj.ColumnConfig in the shifu project by ShifuML.
Taken from the class NormalizerTest, method getZScore3.
/**
 * Checks the three-argument {@code Normalizer.normalize} on a categorical column:
 * the category "2" is expected to normalize to 0.0.
 */
@Test
public void getZScore3() {
    // Categorical column configured with mean 2.0 and standard deviation 1.0.
    ColumnConfig categoricalColumn = new ColumnConfig();
    categoricalColumn.setColumnType(ColumnType.C);
    categoricalColumn.setMean(2.0);
    categoricalColumn.setStdDev(1.0);

    // NOTE(review): five categories are registered but only four positive-case rates;
    // confirm whether the missing fifth rate is intentional.
    categoricalColumn.setBinCategory(Arrays.asList("1", "2", "3", "4", "ABC"));
    categoricalColumn.setBinPosCaseRate(Arrays.asList(0.1, 2.0, 0.3, 0.1));

    // Category "2" has positive-case rate 2.0 in the list above and is expected to map to 0.0.
    Assert.assertEquals(0.0, Normalizer.normalize(categoricalColumn, "2", 0.1).get(0));
    // Disabled check for a category that is not in the list:
    // Assert.assertEquals(0.0, Normalizer.normalize(config, "5", 0.1);
}
Aggregations