use of org.apache.ignite.ml.structures.LabeledVectorSet in project ignite by apache.
the class CollectionsTest method test.
/**
 * Passes pairs of differently configured ML structures and models to the {@code test} and
 * {@code specialTest} helpers, which are declared outside this snippet.
 */
@Test
@SuppressWarnings("unchecked")
public void test() {
    // Matrix views built over dense matrices of different shapes.
    VectorizedViewMatrix smallView = new VectorizedViewMatrix(new DenseMatrix(2, 2), 1, 1, 1, 1);
    VectorizedViewMatrix largeView = new VectorizedViewMatrix(new DenseMatrix(3, 2), 2, 1, 1, 1);
    test(smallView, largeView);

    // Distance measures go through the dedicated helper, each paired with a fresh instance of the same class.
    specialTest(new ManhattanDistance(), new ManhattanDistance());
    specialTest(new HammingDistance(), new HammingDistance());
    specialTest(new EuclideanDistance(), new EuclideanDistance());

    // Metadata that was created as "name2" and then renamed, versus metadata still named "name2".
    FeatureMetadata renamedMeta = new FeatureMetadata("name2");
    renamedMeta.setName("name1");
    test(renamedMeta, new FeatureMetadata("name2"));

    // Rows, labeled vectors and datasets; generic arguments stay inline so type inference is untouched.
    test(
        new DatasetRow<>(new DenseVector()),
        new DatasetRow<>(new DenseVector(1)));
    test(
        new LabeledVector<>(new DenseVector(), null),
        new LabeledVector<>(new DenseVector(1), null));
    test(
        new Dataset<DatasetRow<Vector>>(new DatasetRow[] {}, new FeatureMetadata[] {}),
        new Dataset<DatasetRow<Vector>>(new DatasetRow[] { new DatasetRow() }, new FeatureMetadata[] { new FeatureMetadata() }));

    // Regression model pair differing only in the second constructor argument.
    LogisticRegressionModel logRegA = new LogisticRegressionModel(new DenseVector(), 1.0);
    LogisticRegressionModel logRegB = new LogisticRegressionModel(new DenseVector(), 0.5);
    test(logRegA, logRegB);

    // K-means format and model pairs differing only in the distance measure.
    KMeansModelFormat kMeansFmtA = new KMeansModelFormat(new Vector[] {}, new ManhattanDistance());
    KMeansModelFormat kMeansFmtB = new KMeansModelFormat(new Vector[] {}, new HammingDistance());
    test(kMeansFmtA, kMeansFmtB);

    KMeansModel kMeansA = new KMeansModel(new Vector[] {}, new ManhattanDistance());
    KMeansModel kMeansB = new KMeansModel(new Vector[] {}, new HammingDistance());
    test(kMeansA, kMeansB);

    // SVM model pair differing only in the second constructor argument.
    SVMLinearClassificationModel svmA = new SVMLinearClassificationModel(null, 1.0);
    SVMLinearClassificationModel svmB = new SVMLinearClassificationModel(null, 0.5);
    test(svmA, svmB);

    // ANN model pair built over differently constructed labeled vector sets.
    ANNClassificationModel annA = new ANNClassificationModel(new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat());
    ANNClassificationModel annB = new ANNClassificationModel(new LabeledVectorSet<>(1, 1), new ANNClassificationTrainer.CentroidStat());
    test(annA, annB);

    // ANN model format pair differing only in the first constructor argument.
    ANNModelFormat annFmtA = new ANNModelFormat(1, new ManhattanDistance(), false, new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat());
    ANNModelFormat annFmtB = new ANNModelFormat(2, new ManhattanDistance(), false, new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat());
    test(annFmtA, annFmtB);
}
use of org.apache.ignite.ml.structures.LabeledVectorSet in project ignite by apache.
the class Deltas method updateModel.
/**
 * {@inheritDoc}
 */
@Override
protected <K, V> SVMLinearClassificationModel updateModel(SVMLinearClassificationModel mdl,
    DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
    assert datasetBuilder != null;

    // A label of 0.0 is remapped to -1.0; any other label is passed through untouched.
    IgniteFunction<Double, Double> lbTransformer = lb -> {
        if (lb != 0.0)
            return lb;

        return -1.0;
    };

    // Rebuilds each labeled vector with the remapped label.
    IgniteFunction<LabeledVector<Double>, LabeledVector<Double>> func =
        vec -> new LabeledVector<>(vec.features(), lbTransformer.apply(vec.label()));

    PatchedPreprocessor<K, V, Double, Double> relabelingPreprocessor = new PatchedPreprocessor<>(func, preprocessor);

    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> partBuilder =
        new LabeledDatasetPartitionDataBuilderOnHeap<>(relabelingPreprocessor);

    Vector stateVec;

    try (Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset = datasetBuilder.build(
        envBuilder,
        (env, upstream, upstreamSize) -> new EmptyContext(),
        partBuilder,
        learningEnvironment())) {

        if (mdl != null)
            stateVec = getStateVector(mdl);
        else {
            // Reduce the per-partition column sizes: the right operand wins when present,
            // otherwise the left one, and 0 when both are missing.
            final int colCnt = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (left, right) -> {
                if (right != null)
                    return right;

                return left == null ? 0 : left;
            });

            // One extra slot in front of the feature weights holds the intercept.
            stateVec = initializeWeightsWithZeros(colCnt + 1);
        }

        int iter = 0;

        while (iter < this.getAmountOfIterations()) {
            Vector delta = calculateUpdates(stateVec, dataset);

            // No update could be computed: defer to the fallback helper.
            if (delta == null)
                return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

            // plus() produces a new vector rather than mutating the current one.
            stateVec = stateVec.plus(delta);

            iter++;
        }
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }

    // Element 0 is the intercept, the remaining elements are the weights.
    return new SVMLinearClassificationModel(stateVec.copyOfRange(1, stateVec.size()), stateVec.get(0));
}
use of org.apache.ignite.ml.structures.LabeledVectorSet in project ignite by apache.
the class LabeledDatasetLoader method loadFromTxtFile.
/**
 * Loads a labeled dataset from a text file. Datafile should keep class labels in the first column.
 * <p>
 * Every line is split on {@code separator}; the first token is parsed as the label and the row is handed
 * to {@code parseFeatures} to build the feature vector. If parsing a row raises
 * {@link NumberFormatException}, the row is skipped when {@code isFallOnBadData} is {@code false} and a
 * {@code FileParsingException} is thrown when it is {@code true}.
 *
 * @param pathToFile Path to file.
 * @param separator Element to tokenize row on separate tokens.
 * @param isFallOnBadData Fall on incorrect data if true.
 * @return Labeled Dataset parsed from file.
 * @throws IOException If the file cannot be opened.
 */
public static LabeledVectorSet loadFromTxtFile(Path pathToFile, String separator, boolean isFallOnBadData) throws IOException {
    List<String> list = new ArrayList<>();

    // Files.lines keeps the underlying file open until the stream is closed,
    // so the stream is scoped with try-with-resources to release the handle.
    try (Stream<String> stream = Files.lines(pathToFile)) {
        stream.forEach(list::add);
    }

    final int rowSize = list.size();

    if (rowSize == 0)
        throw new EmptyFileException(pathToFile.toString());

    // The first column is the label, so it does not count as a feature column.
    final int colSize = getColumnSize(separator, list) - 1;

    if (colSize <= 0)
        throw new NoDataException("File should contain first row with data");

    List<Double> labels = new ArrayList<>();
    List<Vector> vectors = new ArrayList<>();

    for (int i = 0; i < rowSize; i++) {
        String[] rowData = list.get(i).split(separator);

        try {
            Double clsLb = Double.parseDouble(rowData[0]);
            Vector vec = parseFeatures(pathToFile, isFallOnBadData, colSize, i, rowData);

            // Both lists are appended only after the whole row parsed, so they stay aligned.
            labels.add(clsLb);
            vectors.add(vec);
        }
        catch (NumberFormatException e) {
            // Best-effort mode: a malformed row is dropped unless the caller asked to fail.
            if (isFallOnBadData)
                throw new FileParsingException(rowData[0], i, pathToFile);
        }
    }

    LabeledVector[] data = new LabeledVector[vectors.size()];

    for (int i = 0; i < vectors.size(); i++)
        data[i] = new LabeledVector(vectors.get(i), labels.get(i));

    return new LabeledVectorSet(data, colSize);
}
use of org.apache.ignite.ml.structures.LabeledVectorSet in project ignite by apache.
the class LocalModelsTest method importExportANNModelTest.
/**
 * Saves an ANN classification model through a {@code FileExporter}, loads the stored format back,
 * rebuilds a model from it and asserts that the rebuilt model equals the one that was saved.
 */
@Test
public void importExportANNModelTest() throws IOException {
    executeModelTest(mdlFilePath -> {
        final LabeledVectorSet<LabeledVector> emptyCandidates = new LabeledVectorSet<>();

        // Model to be exported: k = 4, Manhattan distance, weighted.
        NNClassificationModel srcMdl = new ANNClassificationModel(emptyCandidates, new ANNClassificationTrainer.CentroidStat())
            .withK(4)
            .withDistanceMeasure(new ManhattanDistance())
            .withWeighted(true);

        Exporter<KNNModelFormat, String> fileExporter = new FileExporter<>();

        srcMdl.saveModel(fileExporter, mdlFilePath);

        ANNModelFormat loadedFormat = (ANNModelFormat) fileExporter.load(mdlFilePath);

        Assert.assertNotNull(loadedFormat);

        // Rebuild from the loaded format; the weighted flag is set explicitly rather than read back.
        NNClassificationModel restoredMdl = new ANNClassificationModel(loadedFormat.getCandidates(), new ANNClassificationTrainer.CentroidStat())
            .withK(loadedFormat.getK())
            .withDistanceMeasure(loadedFormat.getDistanceMeasure())
            .withWeighted(true);

        Assert.assertEquals("", srcMdl, restoredMdl);

        return null;
    });
}
Aggregations