Search in sources:

Example 1 with CandidateSeed

Use of ml.shifu.shifu.core.dvarsel.CandidateSeed in the project shifu by ShifuML.

Class WrapperMasterConductor, method voteBestSeed.

/**
 * Tallies the credit recorded for every candidate seed in {@code seedCreditQueue} and returns the
 * seed whose accumulated credit is highest.
 *
 * <p>Null slots in the queue are skipped, and credits for the same seed (per its
 * {@code equals}/{@code hashCode}) are summed. Among seeds with equal totals, the first one
 * encountered while iterating the tally map wins, because only a strictly larger total replaces
 * the current leader. If no credit was recorded at all, the method logs {@link Integer#MIN_VALUE}
 * as the max credit and returns {@code null}.
 *
 * @return the seed with the largest total credit, or {@code null} if the queue held no credits
 */
@Override
public CandidateSeed voteBestSeed() {
    Map<CandidateSeed, Integer> creditBySeed = new HashMap<CandidateSeed, Integer>();
    for (SeedCredit credit : seedCreditQueue) {
        if (credit == null) {
            continue;
        }
        CandidateSeed seed = credit.getSeed();
        Integer soFar = creditBySeed.get(seed);
        int base = (soFar == null) ? 0 : soFar.intValue();
        creditBySeed.put(seed, Integer.valueOf(base + credit.getCredit()));
    }

    CandidateSeed winner = null;
    int topCredit = Integer.MIN_VALUE;
    for (Map.Entry<CandidateSeed, Integer> tally : creditBySeed.entrySet()) {
        int total = tally.getValue();
        if (total > topCredit) {
            topCredit = total;
            winner = tally.getKey();
        }
    }
    LOG.info("With max credit - {}, the best candidate is {}", topCredit, winner);
    return winner;
}
Also used: HashMap (java.util.HashMap), Map (java.util.Map), Iterator (java.util.Iterator), CandidateSeed (ml.shifu.shifu.core.dvarsel.CandidateSeed)

Example 2 with CandidateSeed

Use of ml.shifu.shifu.core.dvarsel.CandidateSeed in the project shifu by ShifuML.

Class CandidateGenerator, method hybrid.

/**
 * Breeds children from the given parent pool until there are as many children as parents.
 *
 * <p>For each attempt, two parents are drawn independently (so they may be the same seed) using
 * {@code rd.nextInt(poolSize)}, and are passed to the two-argument {@code hybrid(father, mother)}.
 * A {@code null} child is discarded and another pair is drawn. An empty pool yields an empty list.
 * NOTE(review): if the two-argument overload can return {@code null} for every possible pair, this
 * loop never terminates. Confirm against that overload's implementation.
 *
 * @param ordinarySeedList the parent pool to draw from
 * @return a new list holding exactly {@code ordinarySeedList.size()} children
 */
private List<CandidateSeed> hybrid(List<CandidateSeed> ordinarySeedList) {
    List<CandidateSeed> children = new ArrayList<CandidateSeed>(ordinarySeedList.size());
    while (children.size() < ordinarySeedList.size()) {
        CandidateSeed father = ordinarySeedList.get(rd.nextInt(ordinarySeedList.size()));
        CandidateSeed mother = ordinarySeedList.get(rd.nextInt(ordinarySeedList.size()));
        CandidateSeed offspring = hybrid(father, mother);
        if (offspring == null) {
            continue;
        }
        children.add(offspring);
    }
    return children;
}
Also used: CandidateSeed (ml.shifu.shifu.core.dvarsel.CandidateSeed)

Example 3 with CandidateSeed

Use of ml.shifu.shifu.core.dvarsel.CandidateSeed in the project shifu by ShifuML.

Class CandidateGenerator, method initSeeds.

/**
 * Builds the initial population of {@code iteratorSeedCount} randomly composed candidate seeds.
 *
 * <p>Each seed gets {@code expectVariableCount} variables. Every pick comes from
 * {@code randomVariable(rd, chosenSoFar)}, which receives the variables already chosen for that
 * seed (presumably so it can avoid duplicates; confirm in {@code randomVariable}). A seed's id is
 * requested from {@code genSeedId()} only after all of its variables have been drawn.
 *
 * @return a new population containing the generated seeds
 */
public CandidatePopulation initSeeds() {
    CandidatePopulation population = new CandidatePopulation(iteratorSeedCount);
    for (int made = 0; made < iteratorSeedCount; made++) {
        List<Integer> chosenVariables = new ArrayList<Integer>(expectVariableCount);
        int drawn = 0;
        while (drawn < expectVariableCount) {
            chosenVariables.add(randomVariable(rd, chosenVariables));
            drawn++;
        }
        CandidateSeed freshSeed = new CandidateSeed(this.genSeedId(), chosenVariables);
        population.addCandidateSeed(freshSeed);
    }
    return population;
}
Also used: CandidatePopulation (ml.shifu.shifu.core.dvarsel.CandidatePopulation), CandidateSeed (ml.shifu.shifu.core.dvarsel.CandidateSeed)

Example 4 with CandidateSeed

Use of ml.shifu.shifu.core.dvarsel.CandidateSeed in the project shifu by ShifuML.

Class WrapperWorkerConductorTest, method testWrapperConductor.

/**
 * Smoke test for {@link WrapperWorkerConductor}: loads the cancer-judgement example model and
 * column configs, feeds the conductor ten seeds, and checks that at least one seed performance
 * comes back.
 *
 * <p>Each seed is a window of six consecutive column ids taken from the ids 2..29.
 */
@Test
public void testWrapperConductor() throws IOException {
    ModelConfig modelConfig = CommonUtils.loadModelConfig(
            "src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ModelConfig.json",
            RawSourceData.SourceType.LOCAL);
    List<ColumnConfig> columnConfigList = CommonUtils.loadColumnConfigList(
            "src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ColumnConfig.json",
            RawSourceData.SourceType.LOCAL);

    WrapperWorkerConductor conductor = new WrapperWorkerConductor(modelConfig, columnConfigList);
    TrainingDataSet dataSet = genTrainingDataSet(modelConfig, columnConfigList);
    conductor.retainData(dataSet);

    List<Integer> candidateColumns = new ArrayList<Integer>();
    for (int columnId = 2; columnId < 30; columnId++) {
        candidateColumns.add(columnId);
    }

    List<CandidateSeed> seeds = new ArrayList<CandidateSeed>();
    for (int offset = 1; offset <= 10; offset++) {
        // NOTE(review): every seed is created with id 0; confirm that is intended here.
        seeds.add(new CandidateSeed(0, candidateColumns.subList(offset, offset + 6)));
    }

    conductor.consumeMasterResult(new VarSelMasterResult(seeds));
    VarSelWorkerResult produced = conductor.generateVarSelResult();

    Assert.assertNotNull(produced);
    Assert.assertTrue(produced.getSeedPerfList().size() > 0);
}
Also used: ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig), ArrayList (java.util.ArrayList), VarSelWorkerResult (ml.shifu.shifu.core.dvarsel.VarSelWorkerResult), CandidateSeed (ml.shifu.shifu.core.dvarsel.CandidateSeed), ModelConfig (ml.shifu.shifu.container.obj.ModelConfig), VarSelMasterResult (ml.shifu.shifu.core.dvarsel.VarSelMasterResult), TrainingDataSet (ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet), Test (org.testng.annotations.Test)

Example 5 with CandidateSeed

Use of ml.shifu.shifu.core.dvarsel.CandidateSeed in the project shifu by ShifuML.

Class CandidateGenerator, method nextGeneration.

/**
 * Produces the next generation of candidate seeds from the workers' evaluation results.
 *
 * <p>The performances returned by {@code getIndividual} are sorted by ascending
 * {@code getVerror()} (presumably the validation error). They are then split into three
 * contiguous bands:
 * <ul>
 *   <li>indices {@code [0, getLastBestIndex]} are passed to {@code inherit};</li>
 *   <li>indices {@code (getLastBestIndex, getFistWorstIndex)} are passed to {@code hybrid};</li>
 *   <li>indices {@code [getFistWorstIndex, size)} are passed to {@code mutate}.</li>
 * </ul>
 * The seeds those calls return make up the new population.
 *
 * @param workerResults evaluation results reported by the workers
 * @param seeds the current population, used to look up seeds by id
 * @return the new population, or {@code seeds} itself when there are no worker results or no
 *     performances to rank
 */
public CandidatePopulation nextGeneration(Iterable<VarSelWorkerResult> workerResults, CandidatePopulation seeds) {
    if (hasNoneResults(workerResults)) {
        return seeds;
    }
    List<CandidatePerf> perfs = getIndividual(workerResults);
    if (perfs.isEmpty()) {
        // Nothing to rank: keep the current population, as in the no-results case above,
        // instead of failing with IndexOutOfBoundsException further down.
        return seeds;
    }
    // Ascending error. The previous comparator never returned 0 (compare(a, a) == 1), which
    // violates the Comparator contract and can make Collections.sort throw
    // "Comparison method violates its general contract!".
    Collections.sort(perfs, new Comparator<CandidatePerf>() {

        @Override
        public int compare(CandidatePerf cpa, CandidatePerf cpb) {
            return Double.compare(cpa.getVerror(), cpb.getVerror());
        }
    });
    // Log up to the five best seeds; a population smaller than five must not crash here.
    int topCount = Math.min(5, perfs.size());
    for (int i = 0; i < topCount; i++) {
        CandidatePerf perf = perfs.get(i);
        LOG.info("The error rate is {}, the best-{} seed: {} ", perf.getVerror(), i, seeds.getSeedById(perf.getId()));
    }
    LOG.info("Worst seed: {}", perfs.get(perfs.size() - 1).toString());

    // Compute the band boundaries once instead of re-evaluating each helper twice.
    int lastBestIndex = getLastBestIndex(perfs);
    int firstWorstIndex = getFistWorstIndex(perfs);
    List<CandidatePerf> bestPerfs = perfs.subList(0, lastBestIndex + 1);
    List<CandidatePerf> ordinaryPerfs = perfs.subList(lastBestIndex + 1, firstWorstIndex);
    List<CandidatePerf> worstPerfs = perfs.subList(firstWorstIndex, perfs.size());

    List<CandidateSeed> bestSeeds = filter(seeds, bestPerfs);
    List<CandidateSeed> ordinarySeeds = filter(seeds, ordinaryPerfs);
    List<CandidateSeed> worstSeeds = filter(seeds, worstPerfs);

    CandidatePopulation result = new CandidatePopulation(iteratorSeedCount);
    result.addCandidateSeedList(inherit(bestSeeds));
    result.addCandidateSeedList(hybrid(ordinarySeeds));
    result.addCandidateSeedList(mutate(worstSeeds));
    // Parameterized so result.toString() is only built when debug logging is enabled.
    LOG.debug("new generation:{}", result);
    return result;
}
Also used: CandidatePerf (ml.shifu.shifu.core.dvarsel.CandidatePerf), CandidatePopulation (ml.shifu.shifu.core.dvarsel.CandidatePopulation), CandidateSeed (ml.shifu.shifu.core.dvarsel.CandidateSeed)

Aggregations

Number of examples in which each type is used:
- CandidateSeed (ml.shifu.shifu.core.dvarsel.CandidateSeed): 5
- CandidatePopulation (ml.shifu.shifu.core.dvarsel.CandidatePopulation): 2
- ArrayList (java.util.ArrayList): 1
- HashMap (java.util.HashMap): 1
- Map (java.util.Map): 1
- ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 1
- ModelConfig (ml.shifu.shifu.container.obj.ModelConfig): 1
- CandidatePerf (ml.shifu.shifu.core.dvarsel.CandidatePerf): 1
- VarSelMasterResult (ml.shifu.shifu.core.dvarsel.VarSelMasterResult): 1
- VarSelWorkerResult (ml.shifu.shifu.core.dvarsel.VarSelWorkerResult): 1
- TrainingDataSet (ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet): 1
- Test (org.testng.annotations.Test): 1