Search in sources:

Example 1 with ColumnConfigComparator

Use of ml.shifu.shifu.container.obj.ColumnConfig.ColumnConfigComparator in project shifu by ShifuML.

From class VariableSelector, method selectByFilter.

/**
 * Runs filter-based variable selection over {@code columnConfigList} and marks the chosen
 * columns with {@code finalSelect = true}.
 *
 * <p>Columns are classified in a single pass:
 * <ul>
 * <li>meta/target columns and force-removed columns are skipped;</li>
 * <li>force-selected columns are kept when both mean and stdDev are non-null;</li>
 * <li>good candidates (numerical, or categorical when categorical variables are enabled,
 * and present in the candidate list if one is configured) are collected for ranking.</li>
 * </ul>
 *
 * <p>If the filter is disabled, only the force-selected columns are marked. Previously set
 * {@code finalSelect} flags are not cleared in that case. Otherwise the flags are reset via
 * {@code resetFinalSelect()}. Candidates are then added according to
 * {@code modelConfig.getVarSelectFilterBy()} until
 * {@code min(forced + candidates, varSelectFilterNum)} columns are selected:
 * <ul>
 * <li>{@code ks} / {@code iv}: in {@code ColumnConfigComparator("ks"/"iv")} order
 * (presumably descending KS/IV — verify against the comparator);</li>
 * <li>{@code mix}: alternates between the KS and IV orders, skipping columns already
 * selected;</li>
 * <li>{@code pareto}: takes the list returned by {@code sortByPareto}, then falls back to the KS
 * order.</li>
 * </ul>
 *
 * @return the same {@code columnConfigList} instance, with {@code finalSelect} updated in place
 * @throws IOException propagated from the helper calls (e.g. loading candidate columns)
 * @throws IllegalArgumentException if the filter is enabled, there is at least one column left
 *     to select, and the filter-by key is null or not one of ks, iv, mix, pareto
 */
public List<ColumnConfig> selectByFilter() throws IOException {
    log.info("    - Method: Filter");
    int ptrKs = 0, ptrIv = 0, ptrPareto = 0, cntByForce = 0;
    // Number of columns actually selected through the KS / IV rankings. The ptr* cursors also
    // advance past already-selected columns in "mix"/"pareto" mode, so they overstate these counts.
    int cntByKs = 0, cntByIv = 0;
    VariableSelector.setFilterNumberByFilterOutRatio(this.modelConfig, this.columnConfigList);
    log.info("Start Variable Selection...");
    log.info("\t VarSelectEnabled: " + modelConfig.getVarSelectFilterEnabled());
    log.info("\t VarSelectBy: " + modelConfig.getVarSelectFilterBy());
    log.info("\t VarSelectNum: " + modelConfig.getVarSelectFilterNum());
    List<Integer> selectedColumnNumList = new ArrayList<Integer>();
    List<ColumnConfig> ksList = new ArrayList<ColumnConfig>();
    List<ColumnConfig> ivList = new ArrayList<ColumnConfig>();
    List<Tuple> paretoList = new ArrayList<Tuple>();
    Set<NSColumn> candidateColumns = CommonUtils.loadCandidateColumns(modelConfig);
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    int cntSelected = 0;
    for (ColumnConfig config : this.columnConfigList) {
        if (config == null) {
            continue;
        }
        if (config.isMeta() || config.isTarget()) {
            log.info("\t Skip meta, weight or target column: " + config.getColumnName());
        } else if (config.isForceRemove()) {
            log.info("\t ForceRemove: " + config.getColumnName());
        } else if (config.isForceSelect()) {
            log.info("\t ForceSelect: " + config.getColumnName());
            if (config.getMean() == null || config.getStdDev() == null) {
                // TODO - check the mean of categorical variable could be null
                log.info("\t ForceSelect Failed: mean/stdDev must not be null");
            } else {
                selectedColumnNumList.add(config.getColumnNum());
                cntSelected++;
                cntByForce++;
            }
        } else if (!CommonUtils.isGoodCandidate(config, hasCandidates)) {
            log.info("\t Incomplete info(please check KS, IV, Mean, or StdDev fields): " + config.getColumnName() + " or it is not in candidate list");
        } else if (CollectionUtils.isNotEmpty(candidateColumns) && !candidateColumns.contains(new NSColumn(config.getColumnName()))) {
            log.info("\t Not in candidate list, Skip: " + config.getColumnName());
        } else if ((config.isCategorical() && !modelConfig.isCategoricalDisabled()) || config.isNumerical()) {
            ksList.add(config);
            ivList.add(config);
            // config is non-null here (null entries are skipped at the top of the loop)
            if (config.getColumnStats() != null) {
                Double ks = config.getKs();
                Double iv = config.getIv();
                paretoList.add(new Tuple(config.getColumnNum(), ks == null ? 0d : ks, iv == null ? 0d : iv));
            }
        }
    }
    // not enabled filter, so only select forceSelect columns
    if (!this.modelConfig.getVarSelectFilterEnabled()) {
        log.info("Summary:");
        log.info("\tSelected Variables: " + cntSelected);
        if (cntByForce != 0) {
            log.info("\t- By Force: " + cntByForce);
        }
        markSelectedColumnsFinal(selectedColumnNumList);
        return columnConfigList;
    }
    String key = this.modelConfig.getVarSelectFilterBy();
    Collections.sort(ksList, new ColumnConfigComparator("ks"));
    Collections.sort(ivList, new ColumnConfigComparator("iv"));
    List<Tuple> newParetoList = sortByPareto(paretoList);
    int expectedVarNum = Math.min(cntSelected + ksList.size(), modelConfig.getVarSelectFilterNum());
    log.info("Expected selected columns:" + expectedVarNum);
    // With an unknown (or null) key the selection loop below would never advance cntSelected
    // and would spin forever, so fail fast before any finalSelect flag is touched.
    if (cntSelected < expectedVarNum && !isSupportedVarSelectFilterBy(key)) {
        throw new IllegalArgumentException("Unsupported VarSelectFilterBy: " + key
                + ", expected one of ks, iv, mix, pareto");
    }
    // reset to false at first.
    resetFinalSelect();
    while (cntSelected < expectedVarNum) {
        if ("ks".equalsIgnoreCase(key)) {
            ColumnConfig config = ksList.get(ptrKs);
            selectedColumnNumList.add(config.getColumnNum());
            ptrKs++;
            log.info("\t SelectedByKS=" + config.getKs() + "(Rank=" + ptrKs + "): " + config.getColumnName());
            cntSelected++;
            cntByKs++;
        } else if ("iv".equalsIgnoreCase(key)) {
            ColumnConfig config = ivList.get(ptrIv);
            selectedColumnNumList.add(config.getColumnNum());
            ptrIv++;
            log.info("\t SelectedByIV=" + config.getIv() + "(Rank=" + ptrIv + "): " + config.getColumnName());
            cntSelected++;
            cntByIv++;
        } else if ("mix".equalsIgnoreCase(key)) {
            // take the next KS-ranked column, then the next IV-ranked one, skipping duplicates
            ColumnConfig config = ksList.get(ptrKs);
            ptrKs++;
            if (selectedColumnNumList.contains(config.getColumnNum())) {
                log.info("\t Variable Already Selected: " + config.getColumnName());
            } else {
                selectedColumnNumList.add(config.getColumnNum());
                log.info("\t SelectedByKS=" + config.getKs() + "(Rank=" + ptrKs + "): " + config.getColumnName());
                cntSelected++;
                cntByKs++;
            }
            if (cntSelected == expectedVarNum) {
                break;
            }
            config = ivList.get(ptrIv);
            ptrIv++;
            if (selectedColumnNumList.contains(config.getColumnNum())) {
                log.info("\t Variable Already Selected: " + config.getColumnName());
            } else {
                selectedColumnNumList.add(config.getColumnNum());
                log.info("\t SelectedByIV=" + config.getIv() + "(Rank=" + ptrIv + "): " + config.getColumnName());
                cntSelected++;
                cntByIv++;
            }
        } else if ("pareto".equalsIgnoreCase(key)) {
            if (ptrPareto >= newParetoList.size()) {
                // Pareto list exhausted: fall back to the KS ranking
                ColumnConfig config = ksList.get(ptrKs);
                if (selectedColumnNumList.contains(config.getColumnNum())) {
                    log.info("\t Variable Already Selected: " + config.getColumnName());
                } else {
                    selectedColumnNumList.add(config.getColumnNum());
                    // parenthesized so the rank is a sum, not two concatenated numbers
                    log.info("\t SelectedByKS=" + config.getKs() + "(Rank=" + (ptrKs + newParetoList.size()) + "): " + config.getColumnName());
                    cntSelected++;
                    cntByKs++;
                }
                ptrKs++;
            } else {
                int columnNum = newParetoList.get(ptrPareto).columnNum;
                selectedColumnNumList.add(columnNum);
                // look up by column id: the id may not be the position in the list after segment support
                ColumnConfig paretoConfig = CommonUtils.getColumnConfig(this.columnConfigList, columnNum);
                log.info("\t SelectedByPareto " + (paretoConfig == null ? String.valueOf(columnNum) : paretoConfig.getColumnName()));
                ptrPareto++;
                cntSelected++;
            }
        } else {
            // unreachable after the up-front validation; kept so the loop can never spin forever
            throw new IllegalStateException("Unsupported VarSelectFilterBy: " + key);
        }
    }
    log.info("Summary:");
    log.info("\t Selected Variables: " + cntSelected);
    if (cntByForce != 0) {
        log.info("\t - By Force: " + cntByForce);
    }
    if (ptrPareto != 0) {
        log.info("\t - By Pareto: " + ptrPareto);
    }
    if (cntByKs != 0) {
        log.info("\t - By KS: " + cntByKs);
    }
    if (cntByIv != 0) {
        log.info("\t - By IV: " + cntByIv);
    }
    markSelectedColumnsFinal(selectedColumnNumList);
    return columnConfigList;
}

/**
 * Sets {@code finalSelect = true} on every column whose column number is in {@code columnNums}.
 * Columns are looked up by column id via {@code CommonUtils.getColumnConfig}, because the id may
 * not be the position in the list after segment support. Ids with no matching column are ignored.
 */
private void markSelectedColumnsFinal(List<Integer> columnNums) {
    for (int n : columnNums) {
        ColumnConfig columnConfig = CommonUtils.getColumnConfig(this.columnConfigList, n);
        if (columnConfig != null) {
            columnConfig.setFinalSelect(true);
        }
    }
}

/** Returns whether {@code key} is a filter-by strategy handled by {@code selectByFilter} (case-insensitive, null-safe). */
private static boolean isSupportedVarSelectFilterBy(String key) {
    return "ks".equalsIgnoreCase(key) || "iv".equalsIgnoreCase(key) || "mix".equalsIgnoreCase(key)
            || "pareto".equalsIgnoreCase(key);
}
Also used: ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig), ColumnConfigComparator (ml.shifu.shifu.container.obj.ColumnConfig.ColumnConfigComparator), NSColumn (ml.shifu.shifu.column.NSColumn)

Example 2 with ColumnConfigComparator

Use of ml.shifu.shifu.container.obj.ColumnConfig.ColumnConfigComparator in project shifu by ShifuML.

From class JavaBeanTest, method testAllJavaBeans.

/**
 * Smoke test for the message and container beans.
 *
 * <p>Runs {@code JavaBeanTester.test} over each listed bean class. It then calls comparators,
 * constructors and accessors that are not covered by the bean tester: comparators, ScoreObject,
 * ValueObject, BinningObject, ModelConfig.createInitModelConfig and ExceptionMessage. This method
 * itself contains no assertions; it relies on JavaBeanTester and on none of these calls throwing.
 * Commented-out classes are currently excluded from the bean tester.
 */
@Test
public void testAllJavaBeans() throws IntrospectionException, IOException {
    // JavaBeanTester.test(ScanStatsRawDataMessage.class);
    // JavaBeanTester.test(ScanTrainDataMessage.class);
    JavaBeanTester.test(AkkaActorInputMessage.class);
    JavaBeanTester.test(ColumnScoreMessage.class);
    JavaBeanTester.test(EvalResultMessage.class);
    JavaBeanTester.test(RunModelDataMessage.class);
    JavaBeanTester.test(RunModelResultMessage.class);
    // JavaBeanTester.test(ScanEvalDataMessage.class);
    JavaBeanTester.test(StatsPartRawDataMessage.class);
    JavaBeanTester.test(StatsValueObjectMessage.class);
    JavaBeanTester.test(TrainResultMessage.class);
    JavaBeanTester.test(TrainPartDataMessage.class);
    JavaBeanTester.test(NormPartRawDataMessage.class);
    JavaBeanTester.test(NormResultDataMessage.class);
    JavaBeanTester.test(StatsResultMessage.class);
    // JavaBeanTester.test(TrainInstanceMessage.class);
    JavaBeanTester.test(ColumnBinning.class);
    JavaBeanTester.test(ColumnStats.class);
    // JavaBeanTester.test(ColumnConfig.class);
    JavaBeanTester.test(ModelBasicConf.class);
    JavaBeanTester.test(ModelSourceDataConf.class);
    JavaBeanTester.test(ModelStatsConf.class);
    JavaBeanTester.test(ModelVarSelectConf.class);
    JavaBeanTester.test(ModelNormalizeConf.class);
    JavaBeanTester.test(ModelTrainConf.class);
    JavaBeanTester.test(ModelConfig.class);
    JavaBeanTester.test(ValueOption.class);
    JavaBeanTester.test(ValidateResult.class);
    JavaBeanTester.test(MetaItem.class);
    JavaBeanTester.test(MetaGroup.class);
    // JavaBeanTester.test(CaseScoreResult.class);
    JavaBeanTester.test(ModelResultObject.class);
    JavaBeanTester.test(PerformanceObject.class);
    JavaBeanTester.test(ReasonResultObject.class);
    // JavaBeanTester.test(ScoreObject.class);
    JavaBeanTester.test(VariableStoreObject.class);
    JavaBeanTester.test(ValueObject.class);
    JavaBeanTester.test(EvalConfig.class);
    JavaBeanTester.test(ModelInitInputObject.class);
    JavaBeanTester.test(WeightAmplifier.class);
    JavaBeanTester.test(ColumnScoreObject.class);
    JavaBeanTester.test(SourceFile.class);

    ModelResultObjectComparator modelResultObjectComparator = new ModelResultObjectComparator();
    modelResultObjectComparator.compare(new ModelResultObject(1, "2", 3d), new ModelResultObject(1, "2", 3d));

    // exercise both comparator keys on a column with KS/IV set
    ColumnConfigComparator cfc = new ColumnConfigComparator("KS");
    ColumnConfig columnConfig = new ColumnConfig();
    columnConfig.setKs(0.0d);
    columnConfig.setIv(0.0d);
    cfc.compare(columnConfig, columnConfig);
    cfc = new ColumnConfigComparator("IV");
    cfc.compare(columnConfig, columnConfig);

    new ScoreObject(Arrays.asList(1d), 1);
    new ScoreObject(Arrays.asList(1d), 0);

    // ValueObjectComparator for both binning data types, on equal and differing values
    ValueObjectComparator voc = new ValueObjectComparator(BinningDataType.Categorical);
    ValueObject valueObject = new ValueObject();
    valueObject.setRaw("123");
    valueObject.setTag("1");
    valueObject.setValue(1.0d);
    voc.compare(valueObject, valueObject);
    ValueObject valueObject2 = new ValueObject();
    valueObject2.setRaw("345");
    valueObject2.setTag("1");
    valueObject2.setValue(2.0d);
    voc.compare(valueObject, valueObject2);
    voc.compare(valueObject, valueObject);
    voc = new ValueObjectComparator(BinningDataType.Numerical);
    voc.compare(valueObject, valueObject2);
    voc.compare(valueObject, valueObject);

    ModelConfig.createInitModelConfig("c", ALGORITHM.NN, "aaa", false);

    // created in the same order as before: numerical, categorical, numerical, categorical
    BinningObject bo = newNumericalBinningObject(1d);
    BinningObject bo2 = newCategoricalBinningObject("111");
    BinningObject bo3 = newNumericalBinningObject(111d);
    BinningObject bo4 = newCategoricalBinningObject("111");
    VariableObjectComparator vooc = new VariableObjectComparator();
    vooc.compare(bo, bo3);
    vooc.compare(bo2, bo4);

    ExceptionMessage es = new ExceptionMessage(new RuntimeException());
    es.setException(new RuntimeException());
    es.getException();
}

/**
 * Creates a numerical BinningObject. Calls its getters on the fresh instance, sets numerical data,
 * score 1.0 and tag "1", then calls toString.
 */
private static BinningObject newNumericalBinningObject(double numericalData) {
    BinningObject bo = new BinningObject(DataType.Numerical);
    bo.getNumericalData();
    bo.getScore();
    bo.getTag();
    bo.getType();
    bo.setNumericalData(numericalData);
    bo.setScore(1.0d);
    bo.setTag("1");
    bo.toString();
    return bo;
}

/**
 * Creates a categorical BinningObject. Calls its getters on the fresh instance, sets categorical
 * data, score 1.0 and tag "1", then calls toString.
 */
private static BinningObject newCategoricalBinningObject(String categoricalData) {
    BinningObject bo = new BinningObject(DataType.Categorical);
    bo.getCategoricalData();
    bo.getScore();
    bo.getTag();
    bo.getType();
    bo.setCategoricalData(categoricalData);
    bo.setScore(1.0d);
    bo.setTag("1");
    bo.toString();
    return bo;
}
Also used: HashMap (java.util.HashMap), ColumnConfigComparator (ml.shifu.shifu.container.obj.ColumnConfig.ColumnConfigComparator), ModelResultObjectComparator (ml.shifu.shifu.container.ModelResultObject.ModelResultObjectComparator), ValueObjectComparator (ml.shifu.shifu.container.ValueObject.ValueObjectComparator), VariableObjectComparator (ml.shifu.shifu.container.BinningObject.VariableObjectComparator), Test (org.testng.annotations.Test)

Aggregations

ColumnConfigComparator (ml.shifu.shifu.container.obj.ColumnConfig.ColumnConfigComparator): 2; HashMap (java.util.HashMap): 1; NSColumn (ml.shifu.shifu.column.NSColumn): 1; VariableObjectComparator (ml.shifu.shifu.container.BinningObject.VariableObjectComparator): 1; ModelResultObjectComparator (ml.shifu.shifu.container.ModelResultObject.ModelResultObjectComparator): 1; ValueObjectComparator (ml.shifu.shifu.container.ValueObject.ValueObjectComparator): 1; ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 1; Test (org.testng.annotations.Test): 1