Search in sources:

Example 26 with ColumnConfig

Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

From class CommonUtilsTest, method getBinNumTest.

/**
 * Checks that BinUtils.getBinNum maps the category "2" of a categorical column to bin 0.
 * "2" is the first entry of the (unsorted) bin-category list.
 */
@Test
public void getBinNumTest() {
    ColumnConfig config = new ColumnConfig();
    config.setColumnName("A");
    config.setColumnType(ColumnType.C);
    // Categories are deliberately not in sorted order.
    config.setBinCategory(Arrays.asList("2", "1", "3"));
    int binNum = BinUtils.getBinNum(config, "2");
    // assertEquals reports actual vs. expected on failure; assertTrue(rt == 0) reported nothing.
    Assert.assertEquals(binNum, 0);
}
Also used: ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) Test(org.testng.annotations.Test)

Example 27 with ColumnConfig

Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

From class BaggingSubsampleUDFTest, method setUp.

/**
 * Prepares the local {@code udf} directory for the tests in this class.
 *
 * <p>Loads ModelConfig.json and ColumnConfig.json from the cancer-judgement example model set,
 * sets bagging to one bag with a bagging sample rate of 2.0, and writes both configs as
 * pretty-printed JSON to {@code udf/ModelConfig.json} and {@code udf/ColumnConfig.json}.
 *
 * @throws Exception if the directory cannot be created or a config cannot be read or written
 */
@BeforeClass
public void setUp() throws Exception {
    File udfDir = new File("udf");
    if (!udfDir.exists()) {
        FileUtils.forceMkdir(udfDir);
    }
    // Single place for the example model-set location instead of repeating it per file.
    String modelSetDir = "src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/";
    ModelConfig modelConfig = CommonUtils.loadModelConfig(modelSetDir + "ModelConfig.json", SourceType.LOCAL);
    List<ColumnConfig> columnConfigList = CommonUtils.loadColumnConfigList(modelSetDir + "ColumnConfig.json",
            SourceType.LOCAL);
    modelConfig.getTrain().setBaggingNum(1);
    // NOTE(review): a rate above 1.0 presumably exercises sampling with replacement — confirm.
    modelConfig.getTrain().setBaggingSampleRate(2.0);
    jsonMapper.writerWithDefaultPrettyPrinter().writeValue(new File(udfDir, "ModelConfig.json"), modelConfig);
    jsonMapper.writerWithDefaultPrettyPrinter().writeValue(new File(udfDir, "ColumnConfig.json"), columnConfigList);
}
Also used: ModelConfig(ml.shifu.shifu.container.obj.ModelConfig) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) File(java.io.File) BeforeClass(org.testng.annotations.BeforeClass)

Example 28 with ColumnConfig

Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

From class ColumnConfigTest, method testSaveColumnConfig.

// @Test
/**
 * Serializes five ColumnConfig objects, named "column0" through "column4", as a JSON array to
 * src/test/resources/reason_data/test2.json. Currently disabled: the @Test annotation is commented out.
 * NOTE(review): writing into src/test/resources from a test mutates checked-in fixtures.
 */
public void testSaveColumnConfig() throws JsonParseException, JsonMappingException, IOException {
    final int count = 5;
    ColumnConfig[] columns = new ColumnConfig[count];
    int idx = 0;
    while (idx < count) {
        ColumnConfig column = new ColumnConfig();
        column.setColumnName("column" + idx);
        columns[idx] = column;
        idx++;
    }
    new ObjectMapper().writeValue(new File("src/test/resources/reason_data/test2.json"), columns);
}
Also used: ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) File(java.io.File) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper)

Example 29 with ColumnConfig

Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

From class ColumnConfigTest, method testLoadColumnConfig.

// @Test
/**
 * Reads src/test/resources/reason_data/test2.json as a JSON array of ColumnConfig and prints
 * the column name of its second entry. Currently disabled: the @Test annotation is commented out.
 * It prints rather than asserts, so it checks only that deserialization does not throw.
 */
public void testLoadColumnConfig() throws JsonParseException, JsonMappingException, IOException {
    File source = new File("src/test/resources/reason_data/test2.json");
    ColumnConfig[] loaded = new ObjectMapper().readValue(source, ColumnConfig[].class);
    // Index 1 is the second element; fewer than two entries throws ArrayIndexOutOfBoundsException.
    ColumnConfig second = loaded[1];
    System.out.println(second.getColumnName());
}
Also used: ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) File(java.io.File) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper)

Example 30 with ColumnConfig

Use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

From class VariableSelector, method selectByFilter.

/**
 * Runs filter-based variable selection and marks the chosen columns with {@code finalSelect = true}.
 *
 * <p>Meta, target and force-removed columns are skipped. Force-selected columns are always
 * selected if their mean and stdDev are non-null. Other columns become candidates when they pass
 * {@code CommonUtils.isGoodCandidate} and the candidate-column list (if any), and are numerical, or
 * categorical with categorical variables enabled.
 *
 * <p>When variable-selection filtering is disabled, only the force-selected columns are marked
 * (existing finalSelect flags are not reset). Otherwise all finalSelect flags are reset first.
 * Candidates are then taken according to {@code varSelectFilterBy}:
 * <ul>
 *   <li>"ks": descending KS order (as sorted by {@code ColumnConfigComparator("ks")})</li>
 *   <li>"iv": IV order (as sorted by {@code ColumnConfigComparator("iv")})</li>
 *   <li>"mix": alternating between the KS list and the IV list, skipping duplicates</li>
 *   <li>"pareto": Pareto order from {@code sortByPareto}, then falling back to the KS list</li>
 * </ul>
 * Selection continues until {@code min(selectedSoFar + candidates, varSelectFilterNum)} columns
 * are selected.
 *
 * @return this selector's column config list, with finalSelect flags updated in place
 * @throws IOException propagated from the configuration helpers this method calls
 * @throws IllegalArgumentException if more columns must be selected but {@code varSelectFilterBy}
 *     is null or not one of ks/iv/mix/pareto (previously this looped forever)
 */
public List<ColumnConfig> selectByFilter() throws IOException {
    log.info("    - Method: Filter");
    int ptrKs = 0, ptrIv = 0, ptrPareto = 0, cntByForce = 0;
    VariableSelector.setFilterNumberByFilterOutRatio(this.modelConfig, this.columnConfigList);
    log.info("Start Variable Selection...");
    log.info("\t VarSelectEnabled: " + modelConfig.getVarSelectFilterEnabled());
    log.info("\t VarSelectBy: " + modelConfig.getVarSelectFilterBy());
    log.info("\t VarSelectNum: " + modelConfig.getVarSelectFilterNum());
    List<Integer> selectedColumnNumList = new ArrayList<Integer>();
    List<ColumnConfig> ksList = new ArrayList<ColumnConfig>();
    List<ColumnConfig> ivList = new ArrayList<ColumnConfig>();
    List<Tuple> paretoList = new ArrayList<Tuple>();
    Set<NSColumn> candidateColumns = CommonUtils.loadCandidateColumns(modelConfig);
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    int cntSelected = 0;
    for (ColumnConfig config : this.columnConfigList) {
        if (config == null) {
            continue;
        }
        if (config.isMeta() || config.isTarget()) {
            log.info("\t Skip meta, weight or target column: " + config.getColumnName());
        } else if (config.isForceRemove()) {
            log.info("\t ForceRemove: " + config.getColumnName());
        } else if (config.isForceSelect()) {
            log.info("\t ForceSelect: " + config.getColumnName());
            if (config.getMean() == null || config.getStdDev() == null) {
                // TODO - check the mean of categorical variable could be null
                log.info("\t ForceSelect Failed: mean/stdDev must not be null");
            } else {
                selectedColumnNumList.add(config.getColumnNum());
                cntSelected++;
                cntByForce++;
            }
        } else if (!CommonUtils.isGoodCandidate(config, hasCandidates)) {
            log.info("\t Incomplete info(please check KS, IV, Mean, or StdDev fields): " + config.getColumnName()
                    + " or it is not in candidate list");
        } else if (CollectionUtils.isNotEmpty(candidateColumns)
                && !candidateColumns.contains(new NSColumn(config.getColumnName()))) {
            log.info("\t Not in candidate list, Skip: " + config.getColumnName());
        } else if ((config.isCategorical() && !modelConfig.isCategoricalDisabled()) || config.isNumerical()) {
            ksList.add(config);
            ivList.add(config);
            // config is already known non-null here; only columns with stats join the Pareto ranking.
            if (config.getColumnStats() != null) {
                Double ks = config.getKs();
                Double iv = config.getIv();
                paretoList.add(new Tuple(config.getColumnNum(), ks == null ? 0d : ks, iv == null ? 0d : iv));
            }
        }
    }
    // not enabled filter, so only select forceSelect columns
    if (!this.modelConfig.getVarSelectFilterEnabled()) {
        log.info("Summary:");
        log.info("\tSelected Variables: " + cntSelected);
        if (cntByForce != 0) {
            log.info("\t- By Force: " + cntByForce);
        }
        markSelectedColumnsFinal(selectedColumnNumList);
        return columnConfigList;
    }
    String key = this.modelConfig.getVarSelectFilterBy();
    Collections.sort(ksList, new ColumnConfigComparator("ks"));
    Collections.sort(ivList, new ColumnConfigComparator("iv"));
    List<Tuple> newParetoList = sortByPareto(paretoList);
    int expectedVarNum = Math.min(cntSelected + ksList.size(), modelConfig.getVarSelectFilterNum());
    log.info("Expected selected columns:" + expectedVarNum);
    // reset to false at first.
    resetFinalSelect();
    ColumnConfig config = null;
    while (cntSelected < expectedVarNum) {
        // Constant-first comparisons so a null key reaches the explicit error below instead of an NPE.
        if ("ks".equalsIgnoreCase(key)) {
            config = ksList.get(ptrKs);
            selectedColumnNumList.add(config.getColumnNum());
            ptrKs++;
            log.info("\t SelectedByKS=" + config.getKs() + "(Rank=" + ptrKs + "): " + config.getColumnName());
            cntSelected++;
        } else if ("iv".equalsIgnoreCase(key)) {
            config = ivList.get(ptrIv);
            selectedColumnNumList.add(config.getColumnNum());
            ptrIv++;
            log.info("\t SelectedByIV=" + config.getIv() + "(Rank=" + ptrIv + "): " + config.getColumnName());
            cntSelected++;
        } else if ("mix".equalsIgnoreCase(key)) {
            config = ksList.get(ptrKs);
            if (selectedColumnNumList.contains(config.getColumnNum())) {
                log.info("\t Variable Already Selected: " + config.getColumnName());
                ptrKs++;
            } else {
                selectedColumnNumList.add(config.getColumnNum());
                ptrKs++;
                log.info("\t SelectedByKS=" + config.getKs() + "(Rank=" + ptrKs + "): " + config.getColumnName());
                cntSelected++;
            }
            if (cntSelected == expectedVarNum) {
                break;
            }
            config = ivList.get(ptrIv);
            if (selectedColumnNumList.contains(config.getColumnNum())) {
                log.info("\t Variable Already Selected: " + config.getColumnName());
                ptrIv++;
            } else {
                selectedColumnNumList.add(config.getColumnNum());
                ptrIv++;
                log.info("\t SelectedByIV=" + config.getIv() + "(Rank=" + ptrIv + "): " + config.getColumnName());
                cntSelected++;
            }
        } else if ("pareto".equalsIgnoreCase(key)) {
            if (ptrPareto >= newParetoList.size()) {
                config = ksList.get(ptrKs);
                if (selectedColumnNumList.contains(config.getColumnNum())) {
                    log.info("\t Variable Already Selected: " + config.getColumnName());
                } else {
                    selectedColumnNumList.add(config.getColumnNum());
                    // Parenthesized: without it the two numbers were string-concatenated ("3" + "5" -> "35").
                    log.info("\t SelectedByKS=" + config.getKs() + "(Rank=" + (ptrKs + newParetoList.size())
                            + "): " + config.getColumnName());
                    cntSelected++;
                }
                ptrKs++;
            } else {
                int columnNum = newParetoList.get(ptrPareto).columnNum;
                selectedColumnNumList.add(columnNum);
                // Look up by column number, not list position: the two may differ once segments exist.
                ColumnConfig paretoConfig = CommonUtils.getColumnConfig(this.columnConfigList, columnNum);
                log.info("\t SelectedByPareto "
                        + (paretoConfig == null ? String.valueOf(columnNum) : paretoConfig.getColumnName()));
                ptrPareto++;
                cntSelected++;
            }
        } else {
            // No branch would advance cntSelected, so the loop would otherwise never terminate.
            throw new IllegalArgumentException("Unsupported varSelectFilterBy '" + key
                    + "'; expected one of ks, iv, mix, pareto");
        }
    }
    log.info("Summary:");
    log.info("\t Selected Variables: " + cntSelected);
    if (cntByForce != 0) {
        log.info("\t - By Force: " + cntByForce);
    }
    if (ptrPareto != 0) {
        log.info("\t - By Pareto: " + ptrPareto);
    }
    if (ptrKs != 0) {
        log.info("\t - By KS: " + ptrKs);
    }
    if (ptrIv != 0) {
        log.info("\t - By IV: " + ptrIv);
    }
    // update column config list and set finalSelect to true
    markSelectedColumnsFinal(selectedColumnNumList);
    return columnConfigList;
}

/**
 * Sets {@code finalSelect = true} on each column whose column number appears in {@code columnNums}.
 * Columns are resolved by column number via {@code CommonUtils.getColumnConfig} rather than by list
 * position, since the id may not be the list position once segments are supported. Numbers with no
 * matching column are skipped.
 */
private void markSelectedColumnsFinal(List<Integer> columnNums) {
    for (int n : columnNums) {
        ColumnConfig columnConfig = CommonUtils.getColumnConfig(this.columnConfigList, n);
        if (columnConfig != null) {
            columnConfig.setFinalSelect(true);
        }
    }
}
Also used: ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ColumnConfigComparator(ml.shifu.shifu.container.obj.ColumnConfig.ColumnConfigComparator) NSColumn(ml.shifu.shifu.column.NSColumn)

Aggregations

ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 131; ArrayList (java.util.ArrayList): 36; Test (org.testng.annotations.Test): 17; IOException (java.io.IOException): 16; HashMap (java.util.HashMap): 12; Tuple (org.apache.pig.data.Tuple): 10; File (java.io.File): 8; NSColumn (ml.shifu.shifu.column.NSColumn): 8; ModelConfig (ml.shifu.shifu.container.obj.ModelConfig): 8; ShifuException (ml.shifu.shifu.exception.ShifuException): 8; Path (org.apache.hadoop.fs.Path): 8; List (java.util.List): 7; Scanner (java.util.Scanner): 7; DataBag (org.apache.pig.data.DataBag): 7; SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType): 5; BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork): 5; TrainingDataSet (ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet): 5; BasicMLData (org.encog.ml.data.basic.BasicMLData): 5; BufferedWriter (java.io.BufferedWriter): 3; FileInputStream (java.io.FileInputStream): 3