Usage of ml.shifu.shifu.core.dvarsel.CandidatePopulation in the shifu project by ShifuML.
Class CandidateGenerator, method initSeeds:
/**
 * Builds the initial population of candidate seeds.
 *
 * <p>Creates {@code iteratorSeedCount} seeds. Each seed holds {@code expectVariableCount}
 * variables drawn one at a time via {@code randomVariable(rd, ...)}. The list built so far
 * is passed in on each draw. Each seed gets an id from {@code genSeedId()}.
 *
 * @return a new population containing the generated seeds
 */
public CandidatePopulation initSeeds() {
    CandidatePopulation population = new CandidatePopulation(iteratorSeedCount);
    for (int built = 0; built < iteratorSeedCount; built++) {
        List<Integer> chosen = new ArrayList<Integer>(expectVariableCount);
        for (int picked = 0; picked < expectVariableCount; picked++) {
            // The partially built list is handed to randomVariable along with the RNG.
            Integer next = randomVariable(rd, chosen);
            chosen.add(next);
        }
        CandidateSeed candidate = new CandidateSeed(this.genSeedId(), chosen);
        population.addCandidateSeed(candidate);
    }
    return population;
}
Usage of ml.shifu.shifu.core.dvarsel.CandidatePopulation in the shifu project by ShifuML.
Class CandidateGeneratorTest, method testNextGeneration:
/**
 * Evolves the initial population for {@code EXPECT_ITERATION_COUNT - 1} rounds via
 * {@code generateNext}. After every round it checks that the population holds 10 seeds.
 * Each population is printed for manual inspection.
 */
@Test
public void testNextGeneration() throws Exception {
    CandidatePopulation population = generator.initSeeds();
    System.out.println(population);
    Random random = new Random();
    // Rounds 1 .. EXPECT_ITERATION_COUNT-1, i.e. one fewer round than EXPECT_ITERATION_COUNT.
    for (int round = 1; round < EXPECT_ITERATION_COUNT; round++) {
        population = generateNext(population, random);
        System.out.println(population);
        Assert.assertEquals(10, population.getSeedList().size());
    }
}
Usage of ml.shifu.shifu.core.dvarsel.CandidatePopulation in the shifu project by ShifuML.
Class CandidateGeneratorTest, method testInitSeeds:
/**
 * Checks that {@code initSeeds()} returns a non-null, non-empty population.
 * The population is also printed for manual inspection.
 */
@Test
public void testInitSeeds() throws Exception {
    CandidatePopulation seeds = generator.initSeeds();
    System.out.println(seeds);
    // Without assertions this test would pass for any non-throwing result, including an
    // empty population; pin the minimal contract the rest of the pipeline relies on.
    Assert.assertNotNull(seeds);
    Assert.assertTrue(seeds.getSeedList().size() > 0);
}
Usage of ml.shifu.shifu.core.dvarsel.CandidatePopulation in the shifu project by ShifuML.
Class CandidateGenerator, method nextGeneration:
/**
 * Produces the next generation of candidate seeds from the workers' results.
 *
 * <p>Candidates are ranked by {@code getVerror()} in ascending order. The ranked list is
 * split into three bands using {@code getLastBestIndex} and {@code getFistWorstIndex}:
 * <ul>
 * <li>the best band's seeds go through {@code inherit};</li>
 * <li>the middle band's seeds go through {@code hybrid};</li>
 * <li>the worst band's seeds go through {@code mutate}.</li>
 * </ul>
 * The outputs are collected into a new population.
 *
 * @param workerResults results reported by the variable-selection workers
 * @param seeds         the current population the results refer to
 * @return the new population, or {@code seeds} unchanged when there are no results to rank
 */
public CandidatePopulation nextGeneration(Iterable<VarSelWorkerResult> workerResults, CandidatePopulation seeds) {
    if (hasNoneResults(workerResults)) {
        return seeds;
    }

    List<CandidatePerf> perfs = getIndividual(workerResults);
    if (perfs.isEmpty()) {
        // Nothing to rank; keep the current population, as in the no-results case above.
        return seeds;
    }

    // Double.compare returns 0 for ties and is antisymmetric. A comparator that never
    // returns 0 violates the Comparator contract, and the sort may then throw
    // IllegalArgumentException ("Comparison method violates its general contract").
    Collections.sort(perfs, new Comparator<CandidatePerf>() {
        @Override
        public int compare(CandidatePerf cpa, CandidatePerf cpb) {
            return Double.compare(cpa.getVerror(), cpb.getVerror());
        }
    });

    // Log up to the five best candidates; bounded so small populations don't overrun the list.
    int loggedCount = Math.min(5, perfs.size());
    for (int i = 0; i < loggedCount; i++) {
        LOG.info("The error rate is {}, the best-{} seed: {} ", perfs.get(i).getVerror(), i, seeds.getSeedById(perfs.get(i).getId()));
    }
    LOG.info("Worst seed: {}", perfs.get(perfs.size() - 1));

    // Band boundaries are computed once and shared by all three sub-lists.
    int lastBestIndex = getLastBestIndex(perfs);
    int firstWorstIndex = getFistWorstIndex(perfs);
    List<CandidatePerf> bestPerfs = perfs.subList(0, lastBestIndex + 1);
    List<CandidatePerf> ordinaryPerfs = perfs.subList(lastBestIndex + 1, firstWorstIndex);
    List<CandidatePerf> worstPerfs = perfs.subList(firstWorstIndex, perfs.size());

    List<CandidateSeed> bestSeeds = filter(seeds, bestPerfs);
    List<CandidateSeed> ordinarySeeds = filter(seeds, ordinaryPerfs);
    List<CandidateSeed> worstSeeds = filter(seeds, worstPerfs);

    CandidatePopulation result = new CandidatePopulation(iteratorSeedCount);
    result.addCandidateSeedList(inherit(bestSeeds));
    result.addCandidateSeedList(hybrid(ordinarySeeds));
    result.addCandidateSeedList(mutate(worstSeeds));
    // Parameterized logging: result.toString() is only evaluated when debug is enabled.
    LOG.debug("new generation:{}", result);
    return result;
}
Aggregations