Use of org.apache.ignite.ml.preprocessing.encoding.target.TargetCounter in the Apache Ignite project.
Class EncoderTrainer, method updateTargetCountersForNextRow.
/**
 * Folds one row into the per-feature target counters.
 *
 * For every handled feature index the row's raw value is turned into a category key, and the counter
 * for that index gets its total count, total target sum, per-category count and per-category target
 * sum increased. The target value is read from the row's features at {@code targetLabelIndex}.
 * A raw value equal to {@code Double.NaN} is replaced in the row by
 * {@link EncoderPreprocessor#KEY_FOR_NULL_VALUES}.
 *
 * @param row Feature vector.
 * @param targetCounters Counters accumulated so far, or {@code null} if this is the first row.
 * @return The updated counters (freshly initialized from the row when {@code null} was passed).
 */
private TargetCounter[] updateTargetCountersForNextRow(LabeledVector row, TargetCounter[] targetCounters) {
    if (targetCounters != null)
        assert targetCounters.length == row.size() : "Base preprocessor must return exactly " + targetCounters.length + " features";
    else
        targetCounters = initializeTargetCounters(row);

    double targetValue = row.features().get(targetLabelIndex);

    for (int idx = 0; idx < targetCounters.length; idx++) {
        if (!handledIndices.contains(idx))
            continue;

        Object rawVal = row.features().getRaw(idx);
        String categoryKey;

        if (rawVal.equals(Double.NaN)) {
            // Missing value: count it under the dedicated key and store that key back into the row.
            categoryKey = EncoderPreprocessor.KEY_FOR_NULL_VALUES;
            row.features().setRaw(idx, categoryKey);
        }
        else if (rawVal instanceof String)
            categoryKey = (String) rawVal;
        else if (rawVal instanceof Number || rawVal instanceof Boolean)
            categoryKey = String.valueOf(rawVal);
        else
            throw new RuntimeException("The type " + rawVal.getClass() + " is not supported for the feature values.");

        TargetCounter counter = targetCounters[idx];

        counter.setTargetCount(counter.getTargetCount() + 1);
        counter.setTargetSum(counter.getTargetSum() + targetValue);

        // Start the category at 1 / targetValue, or add to whatever is already recorded for it.
        Map<String, Long> countsByCategory = counter.getCategoryCounts();
        countsByCategory.merge(categoryKey, 1L, Long::sum);

        Map<String, Double> sumsByCategory = counter.getCategoryTargetSum();
        sumsByCategory.merge(categoryKey, targetValue, Double::sum);
    }

    return targetCounters;
}
Use of org.apache.ignite.ml.preprocessing.encoding.target.TargetCounter in the Apache Ignite project.
Class EncoderTrainer, method calculateTargetEncodingFrequencies.
/**
 * Calculates target encoding metadata for each handled feature: the global target mean and the
 * per-category target means.
 *
 * The per-partition counters are reduced into one array by summing, for every handled index, the
 * target sums, the target counts, the per-category counts and the per-category target sums.
 * The global mean is the reduced target sum divided by the reduced target count.
 *
 * @param dataset Dataset.
 * @return Target encoding metadata for each feature; entries at indices that are not handled are {@code null}.
 * @throws IllegalStateException If no partition produced target counters, so there is nothing to encode from.
 */
private TargetEncodingMeta[] calculateTargetEncodingFrequencies(Dataset<EmptyContext, EncoderPartitionData> dataset) {
    TargetCounter[] targetCounters = dataset.compute(EncoderPartitionData::targetCounters, (a, b) -> {
        // A partition may contribute no counters at all; pass the other side through unchanged.
        if (a == null)
            return b;

        if (b == null)
            return a;

        assert a.length == b.length;

        for (int i = 0; i < a.length; i++) {
            if (handledIndices.contains(i)) {
                TargetCounter src = a[i];
                TargetCounter dst = b[i];

                // Accumulate everything from 'a' into 'b'; 'b' is the reduction result.
                dst.setTargetSum(src.getTargetSum() + dst.getTargetSum());
                dst.setTargetCount(src.getTargetCount() + dst.getTargetCount());

                src.getCategoryCounts().forEach((k, v) -> dst.getCategoryCounts().merge(k, v, Long::sum));
                src.getCategoryTargetSum().forEach((k, v) -> dst.getCategoryTargetSum().merge(k, v, Double::sum));
            }
        }

        return b;
    });

    // The reducer lets null results through, so the overall result is null when no partition
    // produced counters. Fail with a clear message instead of a bare NullPointerException below.
    if (targetCounters == null)
        throw new IllegalStateException("Target counters were not calculated: no partition produced any counters (is the dataset empty?)");

    TargetEncodingMeta[] targetEncodingMetas = new TargetEncodingMeta[targetCounters.length];

    for (int i = 0; i < targetCounters.length; i++) {
        if (handledIndices.contains(i)) {
            TargetCounter targetCounter = targetCounters[i];

            targetEncodingMetas[i] = new TargetEncodingMeta()
                .withGlobalMean(targetCounter.getTargetSum() / targetCounter.getTargetCount())
                .withCategoryMean(calculateCategoryTargetEncodingFrequency(targetCounter));
        }
    }

    return targetEncodingMetas;
}
Aggregations