use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
the class CommonUtilsTest method getBinNumTest.
/**
 * Verifies that {@code BinUtils.getBinNum} returns 0 for the value "2" on a column of type
 * {@code ColumnType.C} whose bin categories are "2", "1", "3" (in that order).
 */
@Test
public void getBinNumTest() {
    ColumnConfig column = new ColumnConfig();
    column.setColumnName("A");
    column.setColumnType(ColumnType.C);
    column.setBinCategory(Arrays.asList("2", "1", "3"));

    // "2" is the first configured category, so the expected bin number is 0.
    int binIndex = BinUtils.getBinNum(column, "2");
    Assert.assertTrue(binIndex == 0);
}
use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
the class BaggingSubsampleUDFTest method setUp.
/**
 * Prepares the local "udf" directory used by this test class.
 *
 * <p>Creates the directory when it does not exist, loads the cancer-judgement example model and
 * column configs from the test resources, overrides the bagging number (1) and bagging sample
 * rate (2.0) on the train section, and writes both configs as pretty-printed JSON into "udf".
 *
 * @throws Exception if the directory cannot be created or a config cannot be loaded or written
 */
@BeforeClass
public void setUp() throws Exception {
    File udfDir = new File("udf");
    if (!udfDir.exists()) {
        FileUtils.forceMkdir(udfDir);
    }

    ModelConfig model = CommonUtils.loadModelConfig(
            "src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ModelConfig.json",
            SourceType.LOCAL);
    List<ColumnConfig> columns = CommonUtils.loadColumnConfigList(
            "src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ColumnConfig.json",
            SourceType.LOCAL);

    // Override the bagging settings before the config is persisted for the UDF under test.
    model.getTrain().setBaggingNum(1);
    model.getTrain().setBaggingSampleRate(2.0);

    File modelOut = new File("udf/ModelConfig.json");
    jsonMapper.writerWithDefaultPrettyPrinter().writeValue(modelOut, model);

    File columnsOut = new File("udf/ColumnConfig.json");
    jsonMapper.writerWithDefaultPrettyPrinter().writeValue(columnsOut, columns);
}
use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
the class ColumnConfigTest method testSaveColumnConfig.
/**
 * Builds five {@code ColumnConfig} objects named "column0" through "column4" and serializes the
 * array with an {@code ObjectMapper} to src/test/resources/reason_data/test2.json.
 *
 * <p>The test annotation below is commented out, so this method is not picked up as a test.
 */
// @Test
public void testSaveColumnConfig() throws JsonParseException, JsonMappingException, IOException {
    ColumnConfig[] columns = new ColumnConfig[5];
    for (int idx = 0; idx < columns.length; idx++) {
        ColumnConfig column = new ColumnConfig();
        column.setColumnName("column" + idx);
        columns[idx] = column;
    }

    File target = new File("src/test/resources/reason_data/test2.json");
    new ObjectMapper().writeValue(target, columns);
}
use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
the class ColumnConfigTest method testLoadColumnConfig.
/**
 * Reads src/test/resources/reason_data/test2.json into a {@code ColumnConfig} array and prints
 * the column name of the second entry to standard output.
 *
 * <p>The test annotation below is commented out, so this method is not picked up as a test.
 */
// @Test
public void testLoadColumnConfig() throws JsonParseException, JsonMappingException, IOException {
    File source = new File("src/test/resources/reason_data/test2.json");
    ColumnConfig[] loaded = new ObjectMapper().readValue(source, ColumnConfig[].class);

    // Index 1 is the second element of the loaded array.
    ColumnConfig second = Arrays.asList(loaded).get(1);
    System.out.println(second.getColumnName());
}
use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
the class VariableSelector method selectByFilter.
/**
 * Runs filter-based variable selection and flags the chosen columns as final-selected.
 *
 * <p>Columns are classified first: meta/target and force-removed columns are skipped,
 * force-selected columns are taken immediately (only when both mean and stdDev are present), and
 * the remaining numerical columns (plus categorical ones unless categorical is disabled) are
 * collected for ranking. When the filter is enabled, candidates are then picked according to
 * {@code modelConfig.getVarSelectFilterBy()} ("ks", "iv", "mix" or "pareto", case-insensitive)
 * until the expected number of variables is reached. An unrecognised or null key stops the
 * ranking step with a warning and keeps only what was selected so far.
 *
 * @return the complete column config list held by this selector (not only the selected columns);
 *         selected columns have {@code finalSelect} set to {@code true}
 * @throws IOException declared by the signature; presumably raised while loading the candidate
 *         columns - confirm against {@code CommonUtils.loadCandidateColumns}
 */
public List<ColumnConfig> selectByFilter() throws IOException {
    log.info(" - Method: Filter");
    int ptrKs = 0, ptrIv = 0, ptrPareto = 0, cntByForce = 0;
    VariableSelector.setFilterNumberByFilterOutRatio(this.modelConfig, this.columnConfigList);
    log.info("Start Variable Selection...");
    log.info("\t VarSelectEnabled: " + modelConfig.getVarSelectFilterEnabled());
    log.info("\t VarSelectBy: " + modelConfig.getVarSelectFilterBy());
    log.info("\t VarSelectNum: " + modelConfig.getVarSelectFilterNum());
    List<Integer> selectedColumnNumList = new ArrayList<Integer>();
    List<ColumnConfig> ksList = new ArrayList<ColumnConfig>();
    List<ColumnConfig> ivList = new ArrayList<ColumnConfig>();
    List<Tuple> paretoList = new ArrayList<Tuple>();
    Set<NSColumn> candidateColumns = CommonUtils.loadCandidateColumns(modelConfig);
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    int cntSelected = 0;

    // Pass 1: classify every column and collect the ranking candidates.
    for (ColumnConfig config : this.columnConfigList) {
        if (config == null) {
            continue;
        }
        if (config.isMeta() || config.isTarget()) {
            log.info("\t Skip meta, weight or target column: " + config.getColumnName());
        } else if (config.isForceRemove()) {
            log.info("\t ForceRemove: " + config.getColumnName());
        } else if (config.isForceSelect()) {
            log.info("\t ForceSelect: " + config.getColumnName());
            if (config.getMean() == null || config.getStdDev() == null) {
                // TODO - check the mean of categorical variable could be null
                log.info("\t ForceSelect Failed: mean/stdDev must not be null");
            } else {
                selectedColumnNumList.add(config.getColumnNum());
                cntSelected++;
                cntByForce++;
            }
        } else if (!CommonUtils.isGoodCandidate(config, hasCandidates)) {
            log.info("\t Incomplete info(please check KS, IV, Mean, or StdDev fields): " + config.getColumnName() + " or it is not in candidate list");
        } else if (CollectionUtils.isNotEmpty(candidateColumns) && !candidateColumns.contains(new NSColumn(config.getColumnName()))) {
            log.info("\t Not in candidate list, Skip: " + config.getColumnName());
        } else if ((config.isCategorical() && !modelConfig.isCategoricalDisabled()) || config.isNumerical()) {
            ksList.add(config);
            ivList.add(config);
            // config is known to be non-null here (checked at the top of the loop).
            if (config.getColumnStats() != null) {
                Double ks = config.getKs();
                Double iv = config.getIv();
                paretoList.add(new Tuple(config.getColumnNum(), ks == null ? 0d : ks, iv == null ? 0d : iv));
            }
        }
    }

    // not enabled filter, so only select forceSelect columns
    if (!this.modelConfig.getVarSelectFilterEnabled()) {
        log.info("Summary:");
        log.info("\tSelected Variables: " + cntSelected);
        if (cntByForce != 0) {
            log.info("\t- By Force: " + cntByForce);
        }
        for (int n : selectedColumnNumList) {
            // Resolve by column id, not by list position: the id may differ from the position
            // in the list after segment support (same lookup as the filter-enabled path below).
            ColumnConfig columnConfig = CommonUtils.getColumnConfig(this.columnConfigList, n);
            if (columnConfig != null) {
                columnConfig.setFinalSelect(true);
            }
        }
        return columnConfigList;
    }

    String key = this.modelConfig.getVarSelectFilterBy();
    Collections.sort(ksList, new ColumnConfigComparator("ks"));
    Collections.sort(ivList, new ColumnConfigComparator("iv"));
    List<Tuple> newParetoList = sortByPareto(paretoList);
    int expectedVarNum = Math.min(cntSelected + ksList.size(), modelConfig.getVarSelectFilterNum());
    log.info("Expected selected columns:" + expectedVarNum);

    // reset to false at first.
    resetFinalSelect();

    // Pass 2: pick ranked candidates until the expected count is reached.
    ColumnConfig picked = null;
    while (cntSelected < expectedVarNum) {
        if ("ks".equalsIgnoreCase(key)) {
            picked = ksList.get(ptrKs);
            selectedColumnNumList.add(picked.getColumnNum());
            ptrKs++;
            log.info("\t SelectedByKS=" + picked.getKs() + "(Rank=" + ptrKs + "): " + picked.getColumnName());
            cntSelected++;
        } else if ("iv".equalsIgnoreCase(key)) {
            picked = ivList.get(ptrIv);
            selectedColumnNumList.add(picked.getColumnNum());
            ptrIv++;
            log.info("\t SelectedByIV=" + picked.getIv() + "(Rank=" + ptrIv + "): " + picked.getColumnName());
            cntSelected++;
        } else if ("mix".equalsIgnoreCase(key)) {
            // Alternate between the KS ranking and the IV ranking, skipping duplicates.
            picked = ksList.get(ptrKs);
            if (selectedColumnNumList.contains(picked.getColumnNum())) {
                log.info("\t Variable Already Selected: " + picked.getColumnName());
                ptrKs++;
            } else {
                selectedColumnNumList.add(picked.getColumnNum());
                ptrKs++;
                log.info("\t SelectedByKS=" + picked.getKs() + "(Rank=" + ptrKs + "): " + picked.getColumnName());
                cntSelected++;
            }
            if (cntSelected == expectedVarNum) {
                break;
            }
            picked = ivList.get(ptrIv);
            if (selectedColumnNumList.contains(picked.getColumnNum())) {
                log.info("\t Variable Already Selected: " + picked.getColumnName());
                ptrIv++;
            } else {
                selectedColumnNumList.add(picked.getColumnNum());
                ptrIv++;
                log.info("\t SelectedByIV=" + picked.getIv() + "(Rank=" + ptrIv + "): " + picked.getColumnName());
                cntSelected++;
            }
        } else if ("pareto".equalsIgnoreCase(key)) {
            if (ptrPareto >= newParetoList.size()) {
                // Pareto list exhausted: fall back to the KS ranking for the remaining slots.
                picked = ksList.get(ptrKs);
                if (selectedColumnNumList.contains(picked.getColumnNum())) {
                    log.info("\t Variable Already Selected: " + picked.getColumnName());
                } else {
                    selectedColumnNumList.add(picked.getColumnNum());
                    // Parenthesised so the two numbers are added rather than concatenated as text.
                    log.info("\t SelectedByKS=" + picked.getKs() + "(Rank=" + (ptrKs + newParetoList.size()) + "): " + picked.getColumnName());
                    cntSelected++;
                }
                ptrKs++;
            } else {
                int columnNum = newParetoList.get(ptrPareto).columnNum;
                selectedColumnNumList.add(columnNum);
                // Look the column up by id; the id is not guaranteed to be the list position.
                ColumnConfig paretoConfig = CommonUtils.getColumnConfig(this.columnConfigList, columnNum);
                log.info("\t SelectedByPareto " + (paretoConfig == null ? String.valueOf(columnNum) : paretoConfig.getColumnName()));
                ptrPareto++;
                cntSelected++;
            }
        } else {
            // No branch above would advance the counters for this key, so the loop would never
            // terminate; stop here and keep whatever has been selected so far.
            log.warn("\t Unsupported VarSelectBy: " + key + ", stop filter selection.");
            break;
        }
    }

    log.info("Summary:");
    log.info("\t Selected Variables: " + cntSelected);
    if (cntByForce != 0) {
        log.info("\t - By Force: " + cntByForce);
    }
    if (ptrPareto != 0) {
        log.info("\t - By Pareto: " + ptrPareto);
    }
    if (ptrKs != 0) {
        log.info("\t - By KS: " + ptrKs);
    }
    if (ptrIv != 0) {
        log.info("\t - By IV: " + ptrIv);
    }

    // update column config list and set finalSelect to true
    for (int n : selectedColumnNumList) {
        // get ColumnConfig by column id. The id may not be the position in the list after segment support
        ColumnConfig columnConfig = CommonUtils.getColumnConfig(this.columnConfigList, n);
        if (columnConfig != null) {
            columnConfig.setFinalSelect(true);
        }
    }
    return columnConfigList;
}
Aggregations