Usage of `ml.shifu.shifu.container.obj.ColumnConfig.ColumnConfigComparator` in project shifu by ShifuML.
Class `VariableSelector`, method `selectByFilter`.
/**
 * Runs filter-based variable selection over {@code columnConfigList}.
 *
 * <p>Each column is classified first: meta/target and force-removed columns are skipped;
 * force-selected columns are selected whenever their mean and stdDev are non-null; remaining
 * candidate columns (categorical ones only when categorical columns are not disabled) are
 * ranked by KS, by IV and by a Pareto ordering of (KS, IV).
 *
 * <p>If the filter is disabled, only the force-selected columns are marked (existing
 * {@code finalSelect} flags are not reset). Otherwise all flags are reset and columns are picked
 * according to {@code modelConfig.getVarSelectFilterBy()} ({@code ks}, {@code iv}, {@code mix}
 * or {@code pareto}, case-insensitive) until {@code getVarSelectFilterNum()} columns, or every
 * eligible column, are selected.
 *
 * @return the same {@code columnConfigList} instance, with {@code finalSelect} set to true on
 *         every selected column
 * @throws IOException propagated from the configuration helpers this method calls
 * @throws IllegalArgumentException if more columns still need to be picked but
 *         {@code varSelectFilterBy} is not one of {@code ks}, {@code iv}, {@code mix},
 *         {@code pareto}
 */
public List<ColumnConfig> selectByFilter() throws IOException {
    log.info(" - Method: Filter");
    int ptrKs = 0, ptrIv = 0, ptrPareto = 0, cntByForce = 0;
    // Columns actually picked via the KS / IV rankings. ptrKs / ptrIv also advance over columns
    // that were already selected (mix mode, pareto fallback), so they are not usable as counts.
    int cntByKs = 0, cntByIv = 0;
    VariableSelector.setFilterNumberByFilterOutRatio(this.modelConfig, this.columnConfigList);
    log.info("Start Variable Selection...");
    log.info("\t VarSelectEnabled: " + modelConfig.getVarSelectFilterEnabled());
    log.info("\t VarSelectBy: " + modelConfig.getVarSelectFilterBy());
    log.info("\t VarSelectNum: " + modelConfig.getVarSelectFilterNum());
    List<Integer> selectedColumnNumList = new ArrayList<Integer>();
    List<ColumnConfig> ksList = new ArrayList<ColumnConfig>();
    List<ColumnConfig> ivList = new ArrayList<ColumnConfig>();
    List<Tuple> paretoList = new ArrayList<Tuple>();
    Set<NSColumn> candidateColumns = CommonUtils.loadCandidateColumns(modelConfig);
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    int cntSelected = 0;
    for (ColumnConfig config : this.columnConfigList) {
        if (config == null) {
            continue;
        }
        if (config.isMeta() || config.isTarget()) {
            log.info("\t Skip meta, weight or target column: " + config.getColumnName());
        } else if (config.isForceRemove()) {
            log.info("\t ForceRemove: " + config.getColumnName());
        } else if (config.isForceSelect()) {
            log.info("\t ForceSelect: " + config.getColumnName());
            if (config.getMean() == null || config.getStdDev() == null) {
                // TODO - check the mean of categorical variable could be null
                log.info("\t ForceSelect Failed: mean/stdDev must not be null");
            } else {
                selectedColumnNumList.add(config.getColumnNum());
                cntSelected++;
                cntByForce++;
            }
        } else if (!CommonUtils.isGoodCandidate(config, hasCandidates)) {
            log.info("\t Incomplete info(please check KS, IV, Mean, or StdDev fields): " + config.getColumnName()
                    + " or it is not in candidate list");
        } else if (CollectionUtils.isNotEmpty(candidateColumns)
                && !candidateColumns.contains(new NSColumn(config.getColumnName()))) {
            log.info("\t Not in candidate list, Skip: " + config.getColumnName());
        } else if ((config.isCategorical() && !modelConfig.isCategoricalDisabled()) || config.isNumerical()) {
            ksList.add(config);
            ivList.add(config);
            // config is already known to be non-null here; only the stats may be missing
            if (config.getColumnStats() != null) {
                Double ks = config.getKs();
                Double iv = config.getIv();
                paretoList.add(new Tuple(config.getColumnNum(), ks == null ? 0d : ks, iv == null ? 0d : iv));
            }
        }
    }

    // not enabled filter, so only select forceSelect columns
    if (!this.modelConfig.getVarSelectFilterEnabled()) {
        log.info("Summary:");
        log.info("\tSelected Variables: " + cntSelected);
        if (cntByForce != 0) {
            log.info("\t- By Force: " + cntByForce);
        }
        // Look columns up by id, not list position: the two can differ once segments are
        // expanded. This matches how the filter-enabled path marks its selection.
        for (int n : selectedColumnNumList) {
            ColumnConfig columnConfig = CommonUtils.getColumnConfig(this.columnConfigList, n);
            if (columnConfig != null) {
                columnConfig.setFinalSelect(true);
            }
        }
        return columnConfigList;
    }

    String key = this.modelConfig.getVarSelectFilterBy();
    Collections.sort(ksList, new ColumnConfigComparator("ks"));
    Collections.sort(ivList, new ColumnConfigComparator("iv"));
    List<Tuple> newParetoList = sortByPareto(paretoList);
    int expectedVarNum = Math.min(cntSelected + ksList.size(), modelConfig.getVarSelectFilterNum());
    log.info("Expected selected columns:" + expectedVarNum);

    // An unknown (or null) strategy would match no branch in the loop below, so cntSelected would
    // never grow and the loop would spin forever. Fail fast, before any flags are reset.
    boolean knownKey = "ks".equalsIgnoreCase(key) || "iv".equalsIgnoreCase(key)
            || "mix".equalsIgnoreCase(key) || "pareto".equalsIgnoreCase(key);
    if (cntSelected < expectedVarNum && !knownKey) {
        throw new IllegalArgumentException("Unsupported varSelect filterBy '" + key
                + "', expected one of: ks, iv, mix, pareto");
    }

    // reset to false at first.
    resetFinalSelect();
    while (cntSelected < expectedVarNum) {
        if ("ks".equalsIgnoreCase(key)) {
            ColumnConfig picked = ksList.get(ptrKs);
            selectedColumnNumList.add(picked.getColumnNum());
            ptrKs++;
            log.info("\t SelectedByKS=" + picked.getKs() + "(Rank=" + ptrKs + "): " + picked.getColumnName());
            cntSelected++;
            cntByKs++;
        } else if ("iv".equalsIgnoreCase(key)) {
            ColumnConfig picked = ivList.get(ptrIv);
            selectedColumnNumList.add(picked.getColumnNum());
            ptrIv++;
            log.info("\t SelectedByIV=" + picked.getIv() + "(Rank=" + ptrIv + "): " + picked.getColumnName());
            cntSelected++;
            cntByIv++;
        } else if ("mix".equalsIgnoreCase(key)) {
            // Alternate between the next-best KS column and the next-best IV column,
            // skipping any column that has already been selected.
            ColumnConfig picked = ksList.get(ptrKs);
            ptrKs++;
            if (selectedColumnNumList.contains(picked.getColumnNum())) {
                log.info("\t Variable Already Selected: " + picked.getColumnName());
            } else {
                selectedColumnNumList.add(picked.getColumnNum());
                log.info("\t SelectedByKS=" + picked.getKs() + "(Rank=" + ptrKs + "): " + picked.getColumnName());
                cntSelected++;
                cntByKs++;
            }
            if (cntSelected == expectedVarNum) {
                break;
            }
            picked = ivList.get(ptrIv);
            ptrIv++;
            if (selectedColumnNumList.contains(picked.getColumnNum())) {
                log.info("\t Variable Already Selected: " + picked.getColumnName());
            } else {
                selectedColumnNumList.add(picked.getColumnNum());
                log.info("\t SelectedByIV=" + picked.getIv() + "(Rank=" + ptrIv + "): " + picked.getColumnName());
                cntSelected++;
                cntByIv++;
            }
        } else {
            // "pareto": every other key was rejected before the loop.
            if (ptrPareto >= newParetoList.size()) {
                // Pareto candidates exhausted: fall back to the KS ranking.
                ColumnConfig picked = ksList.get(ptrKs);
                if (selectedColumnNumList.contains(picked.getColumnNum())) {
                    log.info("\t Variable Already Selected: " + picked.getColumnName());
                } else {
                    selectedColumnNumList.add(picked.getColumnNum());
                    // parenthesized so the rank is added numerically, not string-concatenated
                    log.info("\t SelectedByKS=" + picked.getKs() + "(Rank=" + (ptrKs + newParetoList.size())
                            + "): " + picked.getColumnName());
                    cntSelected++;
                    cntByKs++;
                }
                ptrKs++;
            } else {
                int columnNum = newParetoList.get(ptrPareto).columnNum;
                selectedColumnNumList.add(columnNum);
                // columnNum is a column id, not necessarily a list index (see marking loop below)
                ColumnConfig paretoConfig = CommonUtils.getColumnConfig(this.columnConfigList, columnNum);
                log.info("\t SelectedByPareto "
                        + (paretoConfig == null ? String.valueOf(columnNum) : paretoConfig.getColumnName()));
                ptrPareto++;
                cntSelected++;
            }
        }
    }

    log.info("Summary:");
    log.info("\t Selected Variables: " + cntSelected);
    if (cntByForce != 0) {
        log.info("\t - By Force: " + cntByForce);
    }
    if (ptrPareto != 0) {
        log.info("\t - By Pareto: " + ptrPareto);
    }
    if (cntByKs != 0) {
        log.info("\t - By KS: " + cntByKs);
    }
    if (cntByIv != 0) {
        log.info("\t - By IV: " + cntByIv);
    }
    // update column config list and set finalSelect to true
    for (int n : selectedColumnNumList) {
        // get ColumnConfig by column id. The id may not the position in array list after support segments
        ColumnConfig columnConfig = CommonUtils.getColumnConfig(this.columnConfigList, n);
        if (columnConfig != null) {
            columnConfig.setFinalSelect(true);
        }
    }
    return columnConfigList;
}
Usage of `ml.shifu.shifu.container.obj.ColumnConfig.ColumnConfigComparator` in project shifu by ShifuML.
Class `JavaBeanTest`, method `testAllJavaBeans`.
/**
 * Exercises getters/setters of the message and config beans, and the comparators and small
 * value objects built on top of them.
 *
 * <p>Comparator checks assert the {@link java.util.Comparator} contract that comparing an object
 * with itself returns 0; the other calls only exercise the code paths.
 */
@Test
public void testAllJavaBeans() throws IntrospectionException, IOException {
    // NOTE(review): the commented-out classes are excluded from bean testing; the reason is not
    // recorded here — confirm, then either re-enable or delete them.
    // JavaBeanTester.test(ScanStatsRawDataMessage.class);
    // JavaBeanTester.test(ScanTrainDataMessage.class);
    JavaBeanTester.test(AkkaActorInputMessage.class);
    JavaBeanTester.test(ColumnScoreMessage.class);
    JavaBeanTester.test(EvalResultMessage.class);
    JavaBeanTester.test(RunModelDataMessage.class);
    JavaBeanTester.test(RunModelResultMessage.class);
    // JavaBeanTester.test(ScanEvalDataMessage.class);
    JavaBeanTester.test(StatsPartRawDataMessage.class);
    JavaBeanTester.test(StatsValueObjectMessage.class);
    JavaBeanTester.test(TrainResultMessage.class);
    JavaBeanTester.test(TrainPartDataMessage.class);
    JavaBeanTester.test(NormPartRawDataMessage.class);
    JavaBeanTester.test(NormResultDataMessage.class);
    JavaBeanTester.test(StatsResultMessage.class);
    // JavaBeanTester.test(TrainInstanceMessage.class);
    JavaBeanTester.test(ColumnBinning.class);
    JavaBeanTester.test(ColumnStats.class);
    // JavaBeanTester.test(ColumnConfig.class);
    JavaBeanTester.test(ModelBasicConf.class);
    JavaBeanTester.test(ModelSourceDataConf.class);
    JavaBeanTester.test(ModelStatsConf.class);
    JavaBeanTester.test(ModelVarSelectConf.class);
    JavaBeanTester.test(ModelNormalizeConf.class);
    JavaBeanTester.test(ModelTrainConf.class);
    JavaBeanTester.test(ModelConfig.class);
    JavaBeanTester.test(ValueOption.class);
    JavaBeanTester.test(ValidateResult.class);
    JavaBeanTester.test(MetaItem.class);
    JavaBeanTester.test(MetaGroup.class);
    // JavaBeanTester.test(CaseScoreResult.class);
    JavaBeanTester.test(ModelResultObject.class);
    JavaBeanTester.test(PerformanceObject.class);
    JavaBeanTester.test(ReasonResultObject.class);
    // JavaBeanTester.test(ScoreObject.class);
    JavaBeanTester.test(VariableStoreObject.class);
    JavaBeanTester.test(ValueObject.class);
    JavaBeanTester.test(EvalConfig.class);
    JavaBeanTester.test(ModelInitInputObject.class);
    JavaBeanTester.test(WeightAmplifier.class);
    JavaBeanTester.test(ColumnScoreObject.class);
    JavaBeanTester.test(SourceFile.class);

    ModelResultObjectComparator modelResultObjectComparator = new ModelResultObjectComparator();
    modelResultObjectComparator.compare(new ModelResultObject(1, "2", 3d), new ModelResultObject(1, "2", 3d));

    // ColumnConfigComparator: both the KS and IV keys must treat an object as equal to itself.
    ColumnConfig columnConfig = new ColumnConfig();
    columnConfig.setKs(0.0d);
    columnConfig.setIv(0.0d);
    ColumnConfigComparator cfc = new ColumnConfigComparator("KS");
    int ksSelf = cfc.compare(columnConfig, columnConfig);
    if (ksSelf != 0) {
        throw new AssertionError("ColumnConfigComparator(KS) self-compare must be 0 but was " + ksSelf);
    }
    cfc = new ColumnConfigComparator("IV");
    int ivSelf = cfc.compare(columnConfig, columnConfig);
    if (ivSelf != 0) {
        throw new AssertionError("ColumnConfigComparator(IV) self-compare must be 0 but was " + ivSelf);
    }

    new ScoreObject(Arrays.asList(1d), 1);
    new ScoreObject(Arrays.asList(1d), 0);

    ValueObjectComparator voc = new ValueObjectComparator(BinningDataType.Categorical);
    ValueObject valueObject = new ValueObject();
    valueObject.setRaw("123");
    valueObject.setTag("1");
    valueObject.setValue(1.0d);
    voc.compare(valueObject, valueObject);
    ValueObject valueObject2 = new ValueObject();
    valueObject2.setRaw("345");
    valueObject2.setTag("1");
    valueObject2.setValue(2.0d);
    voc.compare(valueObject, valueObject2);
    int categoricalSelf = voc.compare(valueObject, valueObject);
    if (categoricalSelf != 0) {
        throw new AssertionError("ValueObjectComparator(Categorical) self-compare must be 0 but was "
                + categoricalSelf);
    }
    voc = new ValueObjectComparator(BinningDataType.Numerical);
    voc.compare(valueObject, valueObject2);
    int numericalSelf = voc.compare(valueObject, valueObject);
    if (numericalSelf != 0) {
        throw new AssertionError("ValueObjectComparator(Numerical) self-compare must be 0 but was "
                + numericalSelf);
    }

    ModelConfig.createInitModelConfig("c", ALGORITHM.NN, "aaa", false);

    BinningObject bo = new BinningObject(DataType.Numerical);
    bo.getNumericalData();
    bo.getScore();
    bo.getTag();
    bo.getType();
    bo.setNumericalData(1d);
    bo.setScore(1.0d);
    bo.setTag("1");
    bo.toString();
    BinningObject bo2 = new BinningObject(DataType.Categorical);
    bo2.getCategoricalData();
    bo2.getScore();
    bo2.getTag();
    bo2.getType();
    bo2.setCategoricalData("111");
    bo2.setScore(1.0d);
    bo2.setTag("1");
    bo2.toString();
    BinningObject bo3 = new BinningObject(DataType.Numerical);
    bo3.getNumericalData();
    bo3.getScore();
    bo3.getTag();
    bo3.getType();
    bo3.setNumericalData(111d);
    bo3.setScore(1.0d);
    bo3.setTag("1");
    bo3.toString();
    BinningObject bo4 = new BinningObject(DataType.Categorical);
    bo4.getCategoricalData();
    bo4.getScore();
    bo4.getTag();
    bo4.getType();
    bo4.setCategoricalData("111");
    bo4.setScore(1.0d);
    bo4.setTag("1");
    bo4.toString();
    VariableObjectComparator vooc = new VariableObjectComparator();
    vooc.compare(bo, bo3);
    vooc.compare(bo2, bo4);

    ExceptionMessage es = new ExceptionMessage(new RuntimeException());
    es.setException(new RuntimeException());
    es.getException();
}
Aggregations