Usage of smile.gap.GeneticAlgorithm in the project smile by haifengl.
Class GAFeatureSelection, method learn.
/**
 * Genetic algorithm based feature selection for classification.
 * Each chromosome is a bit string with one bit per feature (column) of
 * {@code x}; its fitness is given by {@code measure} under {@code k}-fold
 * cross validation of models built by {@code trainer}.
 *
 * @param size the population size of Genetic Algorithm.
 * @param generation the maximum number of iterations.
 * @param trainer classifier trainer.
 * @param measure classification measure as the chromosome fitness measure.
 * @param x training instances.
 * @param y training labels.
 * @param k k-fold cross validation for the evaluation.
 * @return bit strings of last generation.
 * @throws IllegalArgumentException if {@code size <= 0}, {@code k < 2},
 *     the lengths of {@code x} and {@code y} differ, or {@code x} is empty.
 */
public BitString[] learn(int size, int generation, ClassifierTrainer<double[]> trainer, ClassificationMeasure measure, double[][] x, int[] y, int k) {
    if (size <= 0) {
        throw new IllegalArgumentException("Invalid population size: " + size);
    }
    if (k < 2) {
        throw new IllegalArgumentException("Invalid k-fold cross validation: " + k);
    }
    if (x.length != y.length) {
        throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
    }
    // Guard before reading x[0]: an empty training set would otherwise fail
    // with an ArrayIndexOutOfBoundsException instead of a descriptive error.
    if (x.length == 0) {
        throw new IllegalArgumentException("Empty training data");
    }

    // One bit per feature of the training data.
    int p = x[0].length;
    ClassificationFitness fitness = new ClassificationFitness(trainer, measure, x, y, k);

    // crossover, crossoverRate, mutationRate and selection are members of the
    // enclosing class; they are not declared in this method.
    BitString[] seeds = new BitString[size];
    for (int i = 0; i < size; i++) {
        seeds[i] = new BitString(p, fitness, crossover, crossoverRate, mutationRate);
    }

    GeneticAlgorithm<BitString> ga = new GeneticAlgorithm<>(seeds, selection);
    ga.evolve(generation);
    // NOTE(review): this returns the seed array itself. It holds the last
    // generation only if GeneticAlgorithm.evolve writes each new generation
    // back into the array passed to its constructor -- confirm.
    return seeds;
}
Usage of smile.gap.GeneticAlgorithm in the project smile by haifengl.
Class GAFeatureSelection, method learn.
/**
 * Genetic algorithm based feature selection for regression.
 * Each chromosome is a bit string with one bit per feature (column) of
 * {@code x}; its fitness is given by {@code measure} under {@code k}-fold
 * cross validation of models built by {@code trainer}.
 *
 * @param size the population size of Genetic Algorithm.
 * @param generation the maximum number of iterations.
 * @param trainer regression model trainer.
 * @param measure regression measure as the chromosome fitness measure.
 * @param x training instances.
 * @param y training instance response variable.
 * @param k k-fold cross validation for the evaluation.
 * @return bit strings of last generation.
 * @throws IllegalArgumentException if {@code size <= 0}, {@code k < 2},
 *     the lengths of {@code x} and {@code y} differ, or {@code x} is empty.
 */
public BitString[] learn(int size, int generation, RegressionTrainer<double[]> trainer, RegressionMeasure measure, double[][] x, double[] y, int k) {
    if (size <= 0) {
        throw new IllegalArgumentException("Invalid population size: " + size);
    }
    if (k < 2) {
        throw new IllegalArgumentException("Invalid k-fold cross validation: " + k);
    }
    if (x.length != y.length) {
        throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
    }
    // Guard before reading x[0]: an empty training set would otherwise fail
    // with an ArrayIndexOutOfBoundsException instead of a descriptive error.
    if (x.length == 0) {
        throw new IllegalArgumentException("Empty training data");
    }

    // One bit per feature of the training data.
    int p = x[0].length;
    RegressionFitness fitness = new RegressionFitness(trainer, measure, x, y, k);

    // crossover, crossoverRate, mutationRate and selection are members of the
    // enclosing class; they are not declared in this method.
    BitString[] seeds = new BitString[size];
    for (int i = 0; i < size; i++) {
        seeds[i] = new BitString(p, fitness, crossover, crossoverRate, mutationRate);
    }

    GeneticAlgorithm<BitString> ga = new GeneticAlgorithm<>(seeds, selection);
    ga.evolve(generation);
    // NOTE(review): this returns the seed array itself. It holds the last
    // generation only if GeneticAlgorithm.evolve writes each new generation
    // back into the array passed to its constructor -- confirm.
    return seeds;
}
Aggregations