Usage of org.apache.ignite.ml.preprocessing.encoding.target.TargetEncodingMeta in the Apache Ignite project.
Class EncoderTrainer, method calculateTargetEncodingFrequencies.
/**
 * Builds target-encoding metadata for every handled feature.
 *
 * Per-partition {@link TargetCounter} arrays are first reduced into a single array. For each index contained in
 * {@code handledIndices}, target sums, target counts, per-category counts and per-category target sums are added
 * together. Each handled feature then receives a {@link TargetEncodingMeta}. Its global mean is the merged target
 * sum divided by the merged target count. Its category means are produced by
 * {@code calculateCategoryTargetEncodingFrequency}.
 *
 * NOTE(review): an empty dataset may make the reduction yield {@code null}, and a zero target count is not
 * guarded against in the division — confirm the intended behavior for both cases.
 *
 * @param dataset Dataset whose partitions provide the target counters.
 * @return Encoding metadata indexed by feature; entries of features outside {@code handledIndices} stay {@code null}.
 */
private TargetEncodingMeta[] calculateTargetEncodingFrequencies(Dataset<EmptyContext, EncoderPartitionData> dataset) {
    TargetCounter[] merged = dataset.compute(EncoderPartitionData::targetCounters, (left, right) -> {
        // A partition without counters contributes nothing; keep whichever side is present.
        if (left == null)
            return right;

        if (right == null)
            return left;

        assert left.length == right.length;

        // Fold the left counters into the right ones in place; the right array becomes the result.
        for (int idx = 0; idx < left.length; idx++) {
            if (!handledIndices.contains(idx))
                continue;

            TargetCounter src = left[idx];
            TargetCounter dst = right[idx];

            dst.setTargetSum(src.getTargetSum() + dst.getTargetSum());
            dst.setTargetCount(src.getTargetCount() + dst.getTargetCount());

            src.getCategoryCounts().forEach((key, cnt) -> dst.getCategoryCounts().merge(key, cnt, Long::sum));
            src.getCategoryTargetSum().forEach((key, sum) -> dst.getCategoryTargetSum().merge(key, sum, Double::sum));
        }

        return right;
    });

    TargetEncodingMeta[] metas = new TargetEncodingMeta[merged.length];

    for (int idx = 0; idx < merged.length; idx++) {
        if (!handledIndices.contains(idx))
            continue;

        TargetCounter counter = merged[idx];

        metas[idx] = new TargetEncodingMeta()
            .withGlobalMean(counter.getTargetSum() / counter.getTargetCount())
            .withCategoryMean(calculateCategoryTargetEncodingFrequency(counter));
    }

    return metas;
}
Usage of org.apache.ignite.ml.preprocessing.encoding.target.TargetEncodingMeta in the Apache Ignite project.
Class TargetEncoderPreprocessorTest, method testApply.
/**
 * Tests {@code apply()} method.
 *
 * Each categorical value found in a feature's category-mean dictionary is expected to be replaced by that mean.
 * Each value missing from the dictionary is expected to be replaced by the feature's global mean.
 */
@Test
public void testApply() {
    Vector[] data = new Vector[] {
        new DenseVector(new Serializable[] {"1", "Moscow", "A"}),
        new DenseVector(new Serializable[] {"2", "Moscow", "B"}),
        new DenseVector(new Serializable[] {"3", "Moscow", "B"})
    };

    Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<>(0, 1, 2);

    // Collections are filled explicitly instead of via double-brace initialization: that idiom creates an
    // anonymous subclass per collection which captures a reference to the enclosing test instance.

    // Feature 0: global mean 0.5.
    HashMap<String, Double> feature0Means = new HashMap<>();
    feature0Means.put("1", 1.0); // category "1" mean = 1.0
    feature0Means.put("2", 0.0); // category "2" mean = 0.0

    // Feature 1: global mean 0.1, no known categories, so every value falls back to the global mean.
    HashMap<String, Double> feature1Means = new HashMap<>();

    // Feature 2: global mean 0.1.
    HashMap<String, Double> feature2Means = new HashMap<>();
    feature2Means.put("A", 1.0); // category "A" mean = 1.0
    feature2Means.put("B", 2.0); // category "B" mean = 2.0

    HashSet<Integer> handledIndices = new HashSet<>();
    handledIndices.add(0);
    handledIndices.add(1);
    handledIndices.add(2);

    TargetEncoderPreprocessor<Integer, Vector> preprocessor = new TargetEncoderPreprocessor<>(
        new TargetEncodingMeta[] {
            new TargetEncodingMeta().withGlobalMean(0.5).withCategoryMean(feature0Means),
            new TargetEncodingMeta().withGlobalMean(0.1).withCategoryMean(feature1Means),
            new TargetEncodingMeta().withGlobalMean(0.1).withCategoryMean(feature2Means)
        },
        vectorizer,
        handledIndices
    );

    double[][] postProcessedData = new double[][] {
        {
            1.0, // "1" is in the dictionary => category mean 1.0
            0.1, // "Moscow" is not in the dictionary => global mean 0.1
            1.0  // "A" is in the dictionary => category mean 1.0
        },
        {
            0.0, // "2" is in the dictionary => category mean 0.0
            0.1, // "Moscow" is not in the dictionary => global mean 0.1
            2.0  // "B" is in the dictionary => category mean 2.0
        },
        {
            0.5, // "3" is not in the dictionary => global mean 0.5
            0.1, // "Moscow" is not in the dictionary => global mean 0.1
            2.0  // "B" is in the dictionary => category mean 2.0
        }
    };

    for (int i = 0; i < data.length; i++) {
        assertArrayEquals(postProcessedData[i], preprocessor.apply(i, data[i]).features().asArray(), 1e-8);
    }
}
Aggregations