use of org.apache.ignite.ml.math.exceptions.preprocessing.UnknownCategorialValueException in project ignite by apache.
the class FrequencyEncoderPreprocessor method apply.
/**
 * Applies this preprocessor.
 *
 * <p>Every handled (categorical) feature is replaced by the value stored for it in
 * {@code encodingFrequencies}; every other feature is passed through as a {@code double}.
 * The label is copied unchanged.
 *
 * @param k Key.
 * @param v Value.
 * @return Preprocessed row.
 * @throws UnknownCategorialValueException If a handled feature value has no entry in
 *     {@code encodingFrequencies} for its column.
 */
@Override
public LabeledVector apply(K k, V v) {
    LabeledVector tmp = basePreprocessor.apply(k, v);
    double[] res = new double[tmp.size()];
    for (int i = 0; i < res.length; i++) {
        Object tmpObj = tmp.getRaw(i);
        if (handledIndices.contains(i)) {
            // NaN is presumably how a missing value arrives; it is looked up under KEY_FOR_NULL_VALUES.
            // Read from the same map that was just checked (encodingFrequencies), not encodingValues.
            if (tmpObj.equals(Double.NaN) && encodingFrequencies[i].containsKey(KEY_FOR_NULL_VALUES))
                res[i] = encodingFrequencies[i].get(KEY_FOR_NULL_VALUES);
            else if (encodingFrequencies[i].containsKey(tmpObj))
                res[i] = encodingFrequencies[i].get(tmpObj);
            else
                throw new UnknownCategorialValueException(tmpObj.toString());
        } else {
            // A plain (double) cast on an Object only accepts java.lang.Double; widen any Number instead.
            res[i] = ((Number) tmpObj).doubleValue();
        }
    }
    return new LabeledVector(VectorUtils.of(res), tmp.label());
}
use of org.apache.ignite.ml.math.exceptions.preprocessing.UnknownCategorialValueException in project ignite by apache.
the class LabelEncoderPreprocessor method apply.
/**
 * Applies this preprocessor: keeps the features produced by the base preprocessor and
 * replaces the label with its encoded value from {@code labelFrequencies}.
 *
 * @param k Key.
 * @param v Value.
 * @return Row with the original features and the encoded label.
 * @throws UnknownCategorialValueException If the label has no entry in {@code labelFrequencies}.
 */
@Override
public LabeledVector apply(K k, V v) {
    LabeledVector row = basePreprocessor.apply(k, v);
    Object rawLbl = row.label();

    // A NaN label is resolved through the dedicated null-value key when one has been recorded.
    boolean useNullKey = rawLbl.equals(Double.NaN) && labelFrequencies.containsKey(KEY_FOR_NULL_VALUES);
    Object lookupKey = useNullKey ? KEY_FOR_NULL_VALUES : rawLbl;

    if (!labelFrequencies.containsKey(lookupKey))
        throw new UnknownCategorialValueException(rawLbl.toString());

    double encodedLbl = labelFrequencies.get(lookupKey);
    return new LabeledVector(row.features(), encodedLbl);
}
use of org.apache.ignite.ml.math.exceptions.preprocessing.UnknownCategorialValueException in project ignite by apache.
the class StringEncoderPreprocessor method apply.
/**
 * Applies this preprocessor.
 *
 * <p>Every handled (categorical) feature is replaced by its encoded value from
 * {@code encodingValues}; every other feature must be numeric and is passed through as a
 * {@code double}. The label is copied unchanged.
 *
 * @param k Key.
 * @param v Value.
 * @return Preprocessed row.
 * @throws UnknownCategorialValueException If a handled feature value has no entry in
 *     {@code encodingValues} for its column.
 * @throws IllegalFeatureTypeException If a non-handled feature is not a {@link Number}.
 */
@Override
public LabeledVector apply(K k, V v) {
    LabeledVector tmp = basePreprocessor.apply(k, v);
    double[] res = new double[tmp.size()];
    for (int i = 0; i < res.length; i++) {
        Object tmpObj = tmp.getRaw(i);
        if (handledIndices.contains(i)) {
            // NaN is presumably how a missing value arrives; it is looked up under KEY_FOR_NULL_VALUES.
            if (tmpObj.equals(Double.NaN) && encodingValues[i].containsKey(KEY_FOR_NULL_VALUES))
                res[i] = encodingValues[i].get(KEY_FOR_NULL_VALUES);
            else if (encodingValues[i].containsKey(tmpObj))
                res[i] = encodingValues[i].get(tmpObj);
            else
                throw new UnknownCategorialValueException(tmpObj.toString());
        } else {
            // Any Number is accepted here, so widen it: a (double) cast would only work for java.lang.Double
            // and would throw ClassCastException for Integer, Long, Float, etc.
            if (tmpObj instanceof Number)
                res[i] = ((Number) tmpObj).doubleValue();
            else
                throw new IllegalFeatureTypeException(tmpObj.getClass(), tmpObj, Double.class);
        }
    }
    return new LabeledVector(VectorUtils.of(res), tmp.label());
}
use of org.apache.ignite.ml.math.exceptions.preprocessing.UnknownCategorialValueException in project ignite by apache.
the class OneHotEncoderPreprocessor method apply.
/**
 * Applies this preprocessor.
 *
 * <p>Non-handled features are copied, in order, to the front of the result. Each handled
 * (categorical) feature is removed and replaced by a single {@code 1.0} in the one-hot area
 * that follows the copied features; its position is computed by {@code getIdxOffset}. The
 * label is copied unchanged.
 *
 * @param k Key.
 * @param v Value.
 * @return Preprocessed row.
 * @throws UnknownCategorialValueException If a handled feature value has no entry in
 *     {@code encodingValues} for its column.
 */
@Override
public LabeledVector apply(K k, V v) {
    LabeledVector tmp = basePreprocessor.apply(k, v);
    int amountOfCategorialFeatures = handledIndices.size();
    // The one-hot area starts right after all pass-through (non-categorical) features.
    int oneHotBase = tmp.size() - amountOfCategorialFeatures;
    double[] res = new double[oneHotBase + getAdditionalSize(encodingValues)];
    int categorialFeatureCntr = 0;
    int resIdx = 0;
    for (int i = 0; i < tmp.size(); i++) {
        Object tmpObj = tmp.getRaw(i);
        if (handledIndices.contains(i)) {
            categorialFeatureCntr++;
            final Integer indexedVal;
            // NaN is presumably how a missing value arrives; it is looked up under KEY_FOR_NULL_VALUES.
            if (tmpObj.equals(Double.NaN) && encodingValues[i].containsKey(KEY_FOR_NULL_VALUES))
                indexedVal = encodingValues[i].get(KEY_FOR_NULL_VALUES);
            else {
                final String key = String.valueOf(tmpObj);
                if (!encodingValues[i].containsKey(key))
                    throw new UnknownCategorialValueException(tmpObj.toString());
                indexedVal = encodingValues[i].get(key);
            }
            res[oneHotBase + getIdxOffset(categorialFeatureCntr, indexedVal, encodingValues)] = 1.0;
        } else {
            // A plain (double) cast on an Object only accepts java.lang.Double; widen any Number instead.
            res[resIdx] = ((Number) tmpObj).doubleValue();
            resIdx++;
        }
    }
    return new LabeledVector(VectorUtils.of(res), tmp.label());
}
use of org.apache.ignite.ml.math.exceptions.preprocessing.UnknownCategorialValueException in project ignite by apache.
the class EncoderTrainerTest method testFitWithUnknownStringValueInTheGivenData.
/**
 * Tests {@code fit()} method: a STRING_ENCODER preprocessor fitted on the numeric rows below
 * must reject a row whose categorical values ("Monday", "September") were never seen during fit.
 */
@Test
public void testFitWithUnknownStringValueInTheGivenData() {
    Map<Integer, Vector> trainingRows = new HashMap<>();
    trainingRows.put(1, VectorUtils.of(3.0, 0.0));
    trainingRows.put(2, VectorUtils.of(3.0, 12.0));
    trainingRows.put(3, VectorUtils.of(3.0, 12.0));
    trainingRows.put(4, VectorUtils.of(2.0, 45.0));
    trainingRows.put(5, VectorUtils.of(2.0, 45.0));
    trainingRows.put(6, VectorUtils.of(14.0, 12.0));

    final Vectorizer<Integer, Vector, Integer, Double> featureExtractor = new DummyVectorizer<>(0, 1);
    DatasetBuilder<Integer, Vector> builder = new LocalDatasetBuilder<>(trainingRows, parts);

    EncoderTrainer<Integer, Vector> trainer = new EncoderTrainer<Integer, Vector>()
        .withEncoderType(EncoderType.STRING_ENCODER)
        .withEncodedFeature(0)
        .withEncodedFeature(1);
    EncoderPreprocessor<Integer, Vector> preprocessor =
        trainer.fit(TestUtils.testEnvBuilder(), builder, featureExtractor);

    try {
        preprocessor.apply(7, new DenseVector(new Serializable[] {"Monday", "September"})).features().asArray();
    } catch (UnknownCategorialValueException expected) {
        // Desired outcome: the unseen categories are rejected.
        return;
    }

    // Reached only when apply() accepted the unseen values.
    fail("UnknownCategorialFeatureValue");
}
Aggregations