Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
From the class NormalizerTest, method categoricalNormalizeTest.
/**
 * Exercises {@code Normalizer.normalize} on a categorical column under each supported
 * {@code NormType}, covering a known category, an unknown category and a {@code null} input.
 */
@Test
public void categoricalNormalizeTest() {
    // Binning fixture: four categories "a".."d". The WOE and count lists carry a fifth, trailing
    // entry; the assertions below expect that trailing value for unknown or null input.
    ColumnBinning binning = new ColumnBinning();
    binning.setBinCountWoe(Arrays.asList(10.0, 11.0, 12.0, 13.0, 6.5));
    binning.setBinWeightedWoe(Arrays.asList(20.0, 21.0, 22.0, 23.0, 16.5));
    binning.setBinCategory(Arrays.asList("a", "b", "c", "d"));
    binning.setBinPosRate(Arrays.asList(0.2, 0.4, 0.8, 1.0));
    binning.setBinCountNeg(Arrays.asList(1, 2, 3, 4, 5));
    binning.setBinCountPos(Arrays.asList(5, 4, 3, 2, 1));

    // Categorical column with mean 0.2 and standard deviation 1.0.
    ColumnConfig categoricalColumn = new ColumnConfig();
    categoricalColumn.setMean(0.2);
    categoricalColumn.setStdDev(1.0);
    categoricalColumn.setColumnType(ColumnType.C);
    categoricalColumn.setColumnBinning(binning);

    // ZSCALE: "b" is expected to give 0.2 (consistent with (0.4 - 0.2) / 1.0 for its positive
    // rate); unknown or null input with CategoryMissingNormType.MEAN is expected to give 0.0.
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, "b", 4.0, NormType.ZSCALE).get(0), 0.2);
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, "b", null, NormType.ZSCALE).get(0), 0.2);
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, "wrong_format", 4.0, NormType.ZSCALE, CategoryMissingNormType.MEAN).get(0), 0.0);
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, null, 4.0, NormType.ZSCALE, CategoryMissingNormType.MEAN).get(0), 0.0);

    // OLD_ZSCALE: "b" is expected to give 0.4 (its configured positive rate); unknown or null
    // input with CategoryMissingNormType.MEAN is expected to give 0.2 (the configured mean).
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, "b", 4.0, NormType.OLD_ZSCALE).get(0), 0.4);
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, "b", null, NormType.OLD_ZSCALE).get(0), 0.4);
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, "wrong_format", 4.0, NormType.OLD_ZSCALE, CategoryMissingNormType.MEAN).get(0), 0.2);
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, null, 4.0, NormType.OLD_ZSCALE, CategoryMissingNormType.MEAN).get(0), 0.2);

    // WEIGHT_WOE / WOE: "c" is expected to give the third list entry, unknown or null input the
    // trailing entry.
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, "c", null, NormType.WEIGHT_WOE).get(0), 22.0);
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, "wrong_format", null, NormType.WEIGHT_WOE).get(0), 16.5);
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, null, null, NormType.WEIGHT_WOE).get(0), 16.5);
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, "c", null, NormType.WOE).get(0), 12.0);
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, "wrong_format", null, NormType.WOE).get(0), 6.5);
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, null, null, NormType.WOE).get(0), 6.5);

    // HYBRID / WEIGHT_HYBRID: for a categorical column the expected values are the same
    // [weighted] WOE figures as above.
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, "a", null, NormType.HYBRID).get(0), 10.0);
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, "wrong_format", null, NormType.HYBRID).get(0), 6.5);
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, null, null, NormType.HYBRID).get(0), 6.5);
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, "a", null, NormType.WEIGHT_HYBRID).get(0), 20.0);
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, "wrong_format", null, NormType.WEIGHT_HYBRID).get(0), 16.5);
    Assert.assertEquals(Normalizer.normalize(categoricalColumn, null, null, NormType.WEIGHT_HYBRID).get(0), 16.5);

    // WOE_ZSCORE / WEIGHT_WOE_ZSCORE: assertions currently disabled and kept for reference.
    // NOTE(review): unlike the active assertions they do not call .get(0) on the result;
    // adjust before re-enabling.
    // Assert.assertEquals(Normalizer.normalize(categoricalColumn, "b", 12.0, NormType.WOE_ZSCORE), 0.2);
    // Assert.assertEquals(Normalizer.normalize(categoricalColumn, "wrong_format", 13.0, NormType.WOE_ZSCORE), -1.6);
    // Assert.assertEquals(Normalizer.normalize(categoricalColumn, null, 13.0, NormType.WOE_ZSCORE), -1.6);
    //
    // Assert.assertEquals(Normalizer.normalize(categoricalColumn, "b", 22.0, NormType.WEIGHT_WOE_ZSCORE), 0.2);
    // Assert.assertEquals(Normalizer.normalize(categoricalColumn, "wrong_format", 23.0, NormType.WEIGHT_WOE_ZSCORE), -1.6);
    // Assert.assertEquals(Normalizer.normalize(categoricalColumn, null, 23.0, NormType.WEIGHT_WOE_ZSCORE), -1.6);
}
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
From the class NormalizerTest, method getZScore2.
/**
 * Checks normalization of a numerical column for an input well below the mean.
 *
 * <p>The column has mean 2.0 and standard deviation 1.0 and the raw input is {@code "-3"}, so
 * the unclipped z-score would be -5.0. The test expects -4.0.
 * NOTE(review): presumably {@code Normalizer} clamps to a default cutoff of 4 standard
 * deviations; confirm against the {@code normalize(config, raw, cutoff)} overload.
 */
@Test
public void getZScore2() {
    ColumnConfig config = new ColumnConfig();
    config.setMean(2.0);
    config.setStdDev(1.0);
    config.setColumnType(ColumnType.N);
    // Actual value first, expected value second: the same argument order used by the other
    // assertions in this test class, so a failure message reports the two values correctly.
    Assert.assertEquals(Normalizer.normalize(config, "-3", null).get(0), -4.0);
}
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
From the class ScorerTest, method scoreTest.
/**
 * Scores one two-feature record against the models held by the enclosing test class and checks
 * the first two returned scores: the first must be greater than 400, the second must equal 1000.
 *
 * <p>Currently disabled: the {@code @Test} annotation below is commented out. {@code models} and
 * {@code modelConfig} are not declared here and come from the enclosing class.
 */
// @Test
public void scoreTest() {
    List<ColumnConfig> selectedColumns = new ArrayList<ColumnConfig>();

    // Numerical column "A" at index 0, marked as final-selected.
    ColumnConfig columnA = new ColumnConfig();
    columnA.setColumnType(ColumnType.N);
    columnA.setColumnName("A");
    columnA.setColumnNum(0);
    columnA.setFinalSelect(true);
    selectedColumns.add(columnA);

    // Numerical column "B" at index 1, marked as final-selected.
    ColumnConfig columnB = new ColumnConfig();
    columnB.setColumnType(ColumnType.N);
    columnB.setColumnName("B");
    columnB.setColumnNum(1);
    columnB.setFinalSelect(true);
    selectedColumns.add(columnB);

    Scorer scorer = new Scorer(models, selectedColumns, "NN", modelConfig);

    // One record: two input features, one ideal (target) value.
    double[] features = { 0., 0. };
    double[] target = { 1. };
    MLDataPair record = new BasicMLDataPair(new BasicMLData(features), new BasicMLData(target));

    ScoreObject result = scorer.score(record, null);
    List<Double> scores = result.getScores();
    Assert.assertTrue(scores.get(0) > 400);
    Assert.assertTrue(scores.get(1) == 1000);
}
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
From the class ScorerTest, method scoreModelsException.
/**
 * Scores a record whose input vector has three values while only two columns are configured,
 * and expects the returned score list to be empty.
 *
 * <p>Currently disabled: the {@code @Test} annotation below is commented out. {@code models} and
 * {@code modelConfig} are not declared here and come from the enclosing class.
 */
// @Test
public void scoreModelsException() {
    List<ColumnConfig> selectedColumns = new ArrayList<ColumnConfig>();

    // Numerical column "A" at index 0, marked as final-selected.
    ColumnConfig columnA = new ColumnConfig();
    columnA.setColumnType(ColumnType.N);
    columnA.setColumnName("A");
    columnA.setColumnNum(0);
    columnA.setFinalSelect(true);
    selectedColumns.add(columnA);

    // Numerical column "B" at index 1, marked as final-selected.
    ColumnConfig columnB = new ColumnConfig();
    columnB.setColumnType(ColumnType.N);
    columnB.setColumnName("B");
    columnB.setColumnNum(1);
    columnB.setFinalSelect(true);
    selectedColumns.add(columnB);

    Scorer scorer = new Scorer(models, selectedColumns, "NN", modelConfig);

    // Three input values for two configured columns: the deliberate mismatch under test.
    double[] features = { 0., 0., 3. };
    double[] target = { 1. };
    MLDataPair record = new BasicMLDataPair(new BasicMLData(features), new BasicMLData(target));

    int scoreCount = scorer.score(record, null).getScores().size();
    Assert.assertEquals(scoreCount, 0);
}
Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
From the class CalculateStatsActor, method onReceive.
/**
 * Dispatches an incoming message by its type; any other type is handed to {@code unhandled}.
 * <ul>
 * <li>{@code AkkaActorInputMessage}: resets the result counter and sends one
 * {@code ScanStatsRawDataMessage} per scanner to {@code dataLoadRef}.</li>
 * <li>{@code StatsResultMessage}: stores the received column config at its column number and
 * counts it. Once the count equals the size of {@code columnNumToActorMap}, the column config
 * list is written to the path supplied by {@code PathFinder} and the actor system is shut
 * down.</li>
 * <li>{@code ExceptionMessage}: shuts the actor system down, then records the exception via
 * {@code addExceptionIntoCondition}.</li>
 * </ul>
 *
 * @param message the message delivered to this actor
 * @throws Exception propagated from handling the message
 * @see akka.actor.UntypedActor#onReceive(java.lang.Object)
 */
@Override
public void onReceive(Object message) throws Exception {
    if (message instanceof AkkaActorInputMessage) {
        // New input: start counting results from zero and fan the scanners out.
        resultCnt = 0;
        List<Scanner> inputScanners = ((AkkaActorInputMessage) message).getScanners();
        log.debug("Num of Scanners: " + inputScanners.size());
        for (Scanner inputScanner : inputScanners) {
            dataLoadRef.tell(new ScanStatsRawDataMessage(inputScanners.size(), inputScanner), getSelf());
        }
        return;
    }

    if (message instanceof StatsResultMessage) {
        ColumnConfig updatedColumn = ((StatsResultMessage) message).getColumnConfig();
        columnConfigList.set(updatedColumn.getColumnNum(), updatedColumn);
        resultCnt++;
        if (resultCnt != columnNumToActorMap.size()) {
            // Still waiting for more results.
            return;
        }
        log.info("Received " + resultCnt + " messages. Finished Calculating Stats.");
        File columnConfigFile = new File(new PathFinder(modelConfig).getColumnConfigPath());
        JSONUtils.writeValue(columnConfigFile, columnConfigList);
        getContext().system().shutdown();
        return;
    }

    if (message instanceof ExceptionMessage) {
        // A child actor reported a failure: shut the system down first ...
        getContext().system().shutdown();
        // ... then record the exception so it is reflected in the return status.
        addExceptionIntoCondition(((ExceptionMessage) message).getException());
        return;
    }

    unhandled(message);
}
Aggregations