Usage of ml.shifu.shifu.core.dvarsel.CandidateSeed in project shifu by ShifuML.
The class WrapperMasterConductor, method voteBestSeed.
/**
 * Sums the credits recorded in {@code seedCreditQueue} for each candidate seed and returns the
 * seed with the largest total.
 *
 * <p>Null slots in the queue are skipped. Totals are kept in a {@link HashMap}; when several
 * seeds share the highest total, the first one met while iterating that map wins, so ties are
 * broken by the map's iteration order. When no credit was recorded at all, {@code null} is
 * returned (and {@code Integer.MIN_VALUE} is logged as the max credit).
 *
 * @return the seed with the highest summed credit, or {@code null} if the queue held no credits
 */
@Override
public CandidateSeed voteBestSeed() {
    Map<CandidateSeed, Integer> totals = new HashMap<CandidateSeed, Integer>();
    for (SeedCredit credit : seedCreditQueue) {
        if (credit == null) {
            continue;
        }
        CandidateSeed seed = credit.getSeed();
        Integer soFar = totals.get(seed);
        int base = (soFar == null) ? 0 : soFar.intValue();
        totals.put(seed, Integer.valueOf(base + credit.getCredit()));
    }

    CandidateSeed winner = null;
    int topCredit = Integer.MIN_VALUE;
    for (Map.Entry<CandidateSeed, Integer> entry : totals.entrySet()) {
        int total = entry.getValue();
        // Strict '>' keeps the earliest entry on ties.
        if (total > topCredit) {
            topCredit = total;
            winner = entry.getKey();
        }
    }
    LOG.info("With max credit - {}, the best candidate is {}", topCredit, winner);
    return winner;
}
Usage of ml.shifu.shifu.core.dvarsel.CandidateSeed in project shifu by ShifuML.
The class CandidateGenerator, method hybrid.
/**
 * Breeds a new list of seeds from {@code ordinarySeedList} by crossing randomly chosen pairs.
 *
 * <p>Each round draws two parents independently and uniformly via {@code rd} (with replacement,
 * so both parents may be the same seed) and passes them to the pairwise
 * {@code hybrid(father, mother)} overload. Non-null children are kept; a {@code null} child is
 * discarded and a new pair is drawn. Drawing stops once the result is as large as the input.
 *
 * <p>NOTE(review): if the pairwise overload can keep returning {@code null} for every possible
 * pair (e.g. a one-element input), this loop never terminates — confirm against that overload.
 *
 * @param ordinarySeedList parent pool; an empty list yields an empty result
 * @return a new list holding exactly {@code ordinarySeedList.size()} children
 */
private List<CandidateSeed> hybrid(List<CandidateSeed> ordinarySeedList) {
    final int target = ordinarySeedList.size();
    List<CandidateSeed> offspring = new ArrayList<CandidateSeed>(target);
    while (offspring.size() < target) {
        CandidateSeed father = ordinarySeedList.get(rd.nextInt(target));
        CandidateSeed mother = ordinarySeedList.get(rd.nextInt(target));
        CandidateSeed child = hybrid(father, mother);
        if (child != null) {
            offspring.add(child);
        }
    }
    return offspring;
}
Usage of ml.shifu.shifu.core.dvarsel.CandidateSeed in project shifu by ShifuML.
The class CandidateGenerator, method initSeeds.
/**
 * Builds the initial population of {@code iteratorSeedCount} randomly generated seeds.
 *
 * <p>Every seed gets {@code expectVariableCount} variables. Each variable is chosen by
 * {@code randomVariable(rd, picked)}, which is given the variables already picked for that seed.
 * The seed id comes from {@code genSeedId()}, requested after the seed's variables are chosen.
 *
 * @return the freshly populated {@link CandidatePopulation}
 */
public CandidatePopulation initSeeds() {
    CandidatePopulation population = new CandidatePopulation(iteratorSeedCount);
    int seedsLeft = iteratorSeedCount;
    while (seedsLeft-- > 0) {
        List<Integer> picked = new ArrayList<Integer>(expectVariableCount);
        for (int varsLeft = expectVariableCount; varsLeft > 0; varsLeft--) {
            picked.add(randomVariable(rd, picked));
        }
        CandidateSeed seed = new CandidateSeed(genSeedId(), picked);
        population.addCandidateSeed(seed);
    }
    return population;
}
Usage of ml.shifu.shifu.core.dvarsel.CandidateSeed in project shifu by ShifuML.
The class WrapperWorkerConductorTest, method testWrapperConductor.
/**
 * Smoke test: a {@link WrapperWorkerConductor} loaded with the cancer-judgement example
 * configuration and data, then given ten overlapping six-column seeds, must produce a worker
 * result with a non-empty seed-performance list.
 */
@Test
public void testWrapperConductor() throws IOException {
    ModelConfig modelConfig = CommonUtils.loadModelConfig("src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ModelConfig.json", RawSourceData.SourceType.LOCAL);
    List<ColumnConfig> columnConfigList = CommonUtils.loadColumnConfigList("src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ColumnConfig.json", RawSourceData.SourceType.LOCAL);

    WrapperWorkerConductor conductor = new WrapperWorkerConductor(modelConfig, columnConfigList);
    conductor.retainData(genTrainingDataSet(modelConfig, columnConfigList));

    // Candidate column ids 2..29 inclusive.
    List<Integer> candidateColumns = new ArrayList<Integer>();
    for (int columnId = 2; columnId < 30; columnId++) {
        candidateColumns.add(columnId);
    }

    // Ten seeds, each a six-column window starting at offsets 1..10; every seed uses id 0.
    List<CandidateSeed> seeds = new ArrayList<CandidateSeed>();
    for (int offset = 1; offset <= 10; offset++) {
        seeds.add(new CandidateSeed(0, candidateColumns.subList(offset, offset + 6)));
    }

    conductor.consumeMasterResult(new VarSelMasterResult(seeds));
    VarSelWorkerResult result = conductor.generateVarSelResult();

    Assert.assertNotNull(result);
    Assert.assertTrue(result.getSeedPerfList().size() > 0);
}
Usage of ml.shifu.shifu.core.dvarsel.CandidateSeed in project shifu by ShifuML.
The class CandidateGenerator, method nextGeneration.
/**
 * Evolves {@code seeds} into the next population using the validation errors reported by workers.
 *
 * <p>The per-seed performances are sorted by ascending validation error ({@code getVerror()}) and
 * cut into three consecutive bands at {@code getLastBestIndex} and {@code getFistWorstIndex}. The
 * best band is passed through {@code inherit}, the middle band through {@code hybrid}, and the
 * worst band through {@code mutate}. Their outputs together form the new population.
 *
 * @param workerResults results reported by the workers for the current population
 * @param seeds         the current population
 * @return the next population, or {@code seeds} unchanged when there is nothing to rank
 */
public CandidatePopulation nextGeneration(Iterable<VarSelWorkerResult> workerResults, CandidatePopulation seeds) {
    if (hasNoneResults(workerResults)) {
        return seeds;
    }
    List<CandidatePerf> perfs = getIndividual(workerResults);
    if (perfs.isEmpty()) {
        // Nothing to rank; keep the current population, mirroring the no-results case above.
        return seeds;
    }
    // Double.compare is a consistent total order (0 on ties). A comparator that never returns 0
    // violates the Comparator contract and can make Collections.sort throw.
    Collections.sort(perfs, new Comparator<CandidatePerf>() {
        @Override
        public int compare(CandidatePerf cpa, CandidatePerf cpb) {
            return Double.compare(cpa.getVerror(), cpb.getVerror());
        }
    });
    // Log at most five best seeds; a smaller population must not throw IndexOutOfBoundsException.
    int topCount = Math.min(5, perfs.size());
    for (int i = 0; i < topCount; i++) {
        LOG.info("The error rate is {}, the best-{} seed: {} ", perfs.get(i).getVerror(), i, seeds.getSeedById(perfs.get(i).getId()));
    }
    LOG.info("Worst seed: {}", perfs.get(perfs.size() - 1).toString());

    // Compute the band boundaries once and reuse them for all three sub-lists.
    int lastBestIndex = getLastBestIndex(perfs);
    int firstWorstIndex = getFistWorstIndex(perfs);
    List<CandidatePerf> bestPerfs = perfs.subList(0, lastBestIndex + 1);
    List<CandidatePerf> ordinaryPerfs = perfs.subList(lastBestIndex + 1, firstWorstIndex);
    List<CandidatePerf> worstPerfs = perfs.subList(firstWorstIndex, perfs.size());

    List<CandidateSeed> bestSeeds = filter(seeds, bestPerfs);
    List<CandidateSeed> ordinarySeeds = filter(seeds, ordinaryPerfs);
    List<CandidateSeed> worstSeeds = filter(seeds, worstPerfs);

    CandidatePopulation result = new CandidatePopulation(iteratorSeedCount);
    result.addCandidateSeedList(inherit(bestSeeds));
    result.addCandidateSeedList(hybrid(ordinarySeeds));
    result.addCandidateSeedList(mutate(worstSeeds));
    // Parameterized form defers result.toString() until debug logging is actually enabled.
    LOG.debug("new generation:{}", result);
    return result;
}
Aggregations