Use of `ml.shifu.shifu.core.dvarsel.CandidatePerf` in the shifu project by ShifuML.
From class `CandidateGenerator`, method `getIndividual`.
/**
 * Collapses the seed performances reported by all workers into one entry per seed id.
 *
 * <p>Every worker result contributes the validation error it measured for each seed. The
 * errors are grouped by seed id, and each group is reduced with {@code mean} to produce a
 * single {@link CandidatePerf}.
 *
 * @param workerResults the results returned by the variable-selection workers
 * @return one {@link CandidatePerf} per distinct seed id, carrying the averaged validation error
 */
private List<CandidatePerf> getIndividual(Iterable<VarSelWorkerResult> workerResults) {
    // seed id -> validation errors reported for that seed, one per worker that evaluated it
    Map<Integer, List<Double>> errorsById = new HashMap<Integer, List<Double>>();
    for (VarSelWorkerResult result : workerResults) {
        for (CandidatePerf seedPerf : result.getSeedPerfList()) {
            Integer seedId = seedPerf.getId();
            List<Double> errors = errorsById.get(seedId);
            if (errors == null) {
                // first report for this seed: start its error list
                errors = new ArrayList<Double>();
                errorsById.put(seedId, errors);
            }
            errors.add(seedPerf.getVerror());
        }
    }

    List<CandidatePerf> averaged = new ArrayList<CandidatePerf>(errorsById.size());
    for (Map.Entry<Integer, List<Double>> idAndErrors : errorsById.entrySet()) {
        double meanError = mean(idAndErrors.getValue());
        averaged.add(new CandidatePerf(idAndErrors.getKey(), meanError));
    }
    return averaged;
}
Use of `ml.shifu.shifu.core.dvarsel.CandidatePerf` in the shifu project by ShifuML.
From class `CandidateGenerator`, method `nextGeneration`.
/**
 * Builds the next generation of candidate seeds from the workers' validation results.
 *
 * <p>Seeds are ranked by their averaged validation error, ascending. The list is then cut into
 * three bands using {@code getLastBestIndex} and {@code getFistWorstIndex}. The best band goes
 * through {@code inherit}, the middle band through {@code hybrid}, and the worst band through
 * {@code mutate}. The outputs of the three are combined into a new population sized by
 * {@code iteratorSeedCount}.
 *
 * @param workerResults results reported by the variable-selection workers
 * @param seeds the current population; used to look up seeds by id
 * @return the new population, or {@code seeds} unchanged when there is nothing to rank
 */
public CandidatePopulation nextGeneration(Iterable<VarSelWorkerResult> workerResults, CandidatePopulation seeds) {
    if (hasNoneResults(workerResults)) {
        return seeds;
    }
    List<CandidatePerf> perfs = getIndividual(workerResults);
    if (perfs.isEmpty()) {
        // Workers reported, but none carried seed performances: nothing to rank.
        return seeds;
    }

    // Double.compare returns 0 for ties, as the Comparator contract requires. A comparator
    // that never returns 0 can make Collections.sort (TimSort) throw IllegalArgumentException.
    Collections.sort(perfs, new Comparator<CandidatePerf>() {
        @Override
        public int compare(CandidatePerf cpa, CandidatePerf cpb) {
            return Double.compare(cpa.getVerror(), cpb.getVerror());
        }
    });

    // Log at most the five best seeds; the population may be smaller than that.
    final int topToLog = Math.min(5, perfs.size());
    for (int i = 0; i < topToLog; i++) {
        CandidatePerf perf = perfs.get(i);
        LOG.info("The error rate is {}, the best-{} seed: {} ", perf.getVerror(), i, seeds.getSeedById(perf.getId()));
    }
    LOG.info("Worst seed: {}", perfs.get(perfs.size() - 1));

    // Band boundaries are computed once from the sorted list:
    // [0, bestEnd) best, [bestEnd, worstStart) ordinary, [worstStart, size) worst.
    int bestEnd = getLastBestIndex(perfs) + 1;
    int worstStart = getFistWorstIndex(perfs);
    List<CandidatePerf> bestPerfs = perfs.subList(0, bestEnd);
    List<CandidatePerf> ordinaryPerfs = perfs.subList(bestEnd, worstStart);
    List<CandidatePerf> worstPerfs = perfs.subList(worstStart, perfs.size());

    List<CandidateSeed> bestSeeds = filter(seeds, bestPerfs);
    List<CandidateSeed> ordinarySeeds = filter(seeds, ordinaryPerfs);
    List<CandidateSeed> worstSeeds = filter(seeds, worstPerfs);

    CandidatePopulation result = new CandidatePopulation(iteratorSeedCount);
    result.addCandidateSeedList(inherit(bestSeeds));
    result.addCandidateSeedList(hybrid(ordinarySeeds));
    result.addCandidateSeedList(mutate(worstSeeds));
    // Parameterized form defers the (potentially large) toString until debug is enabled.
    LOG.debug("new generation:{}", result);
    return result;
}
Aggregations