Use of org.tribuo.clustering.ClusterID in project ml-commons by opensearch-project.
Class TribuoUtil, method generateDataset.
/**
 * Converts a data frame into a tribuo dataset, producing one example per data frame row.
 * Each example is given a placeholder output selected by {@code outputType}:
 * an unassigned {@code ClusterID}, a single-dimension {@code Regressor} named "DIM-0"
 * holding NaN, or an {@code Event} of type EXPECTED.
 *
 * @param dataFrame features data
 * @param outputFactory the tribuo output factory
 * @param desc description for tribuo provenance
 * @param outputType the tribuo output type
 * @return tribuo dataset
 * @throws IllegalArgumentException if a row is processed with an output type that is not handled here
 */
public static <T extends Output<T>> MutableDataset<T> generateDataset(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType) {
    List<Example<T>> examples = new ArrayList<>();
    Tuple<String[], double[][]> namesAndValues = transformDataFrame(dataFrame);
    for (int row = 0; row < dataFrame.size(); ++row) {
        // Pick the placeholder output first; the example itself is built the same way for every type.
        final T placeholder;
        switch (outputType) {
            case CLUSTERID:
                placeholder = (T) new ClusterID(ClusterID.UNASSIGNED);
                break;
            case REGRESSOR:
                // Single dimension tribuo regressor with name DIM-0 and value double NaN.
                placeholder = (T) new Regressor("DIM-0", Double.NaN);
                break;
            case ANOMALY_DETECTION_LIBSVM:
                // The default event type is EXPECTED (non-anomalous) because:
                // 1. For training data, Tribuo LibSVMAnomalyTrainer only supports EXPECTED events at training time.
                // 2. For prediction data, the data is treated as non-anomalous by default as Tribuo does not accept the UNKNOWN type.
                // TODO: support anomaly labels to evaluate prediction result
                placeholder = (T) new Event(Event.EventType.EXPECTED);
                break;
            default:
                throw new IllegalArgumentException("unknown type:" + outputType);
        }
        ArrayExample<T> rowExample = new ArrayExample<>(placeholder, namesAndValues.v1(), namesAndValues.v2()[row]);
        examples.add(rowExample);
    }
    return new MutableDataset<>(new ListDataSource<>(examples, outputFactory, new SimpleDataSourceProvenance(desc, outputFactory)));
}
Use of org.tribuo.clustering.ClusterID in project ml-commons by opensearch-project.
Class KMeans, method trainAndPredict.
/**
 * Trains a KMeans model on the given data frame and then predicts cluster ids for the
 * same data. The trained model is used only to produce the predictions returned here.
 *
 * @param dataFrame data used for both training and prediction
 * @return prediction output containing one "ClusterID" entry per prediction
 */
@Override
public MLOutput trainAndPredict(DataFrame dataFrame) {
    MutableDataset<ClusterID> dataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans training and predicting data from opensearch", TribuoOutputType.CLUSTERID);

    // Fall back to the defaults when a parameter was not supplied.
    Integer centroids = parameters.getCentroids();
    if (centroids == null) {
        centroids = DEFAULT_CENTROIDS;
    }
    Integer iterations = parameters.getIterations();
    if (iterations == null) {
        iterations = DEFAULT_ITERATIONS;
    }

    // won't store model in index
    KMeansModel trainedModel = new KMeansTrainer(centroids, iterations, distance, numThreads, seed).train(dataset);

    List<Map<String, Object>> clusterRows = new ArrayList<>();
    for (Prediction<ClusterID> prediction : trainedModel.predict(dataset)) {
        clusterRows.add(Collections.singletonMap("ClusterID", prediction.getOutput().getID()));
    }
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(clusterRows)).build();
}
Use of org.tribuo.clustering.ClusterID in project ml-commons by opensearch-project.
Class KMeans, method predict.
/**
 * Predicts cluster ids for the given data frame with a previously serialized KMeans model.
 *
 * @param dataFrame data to run the prediction on
 * @param model holder of the serialized KMeans model content; must not be null
 * @return prediction output containing one "ClusterID" entry per prediction
 * @throws IllegalArgumentException if {@code model} is null
 */
@Override
public MLOutput predict(DataFrame dataFrame, Model model) {
    if (model == null) {
        throw new IllegalArgumentException("No model found for KMeans prediction.");
    }

    // The dataset is built before the model content is deserialized.
    MutableDataset<ClusterID> dataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans prediction data from opensearch", TribuoOutputType.CLUSTERID);
    KMeansModel restoredModel = (KMeansModel) ModelSerDeSer.deserialize(model.getContent());

    List<Map<String, Object>> clusterRows = new ArrayList<>();
    for (Prediction<ClusterID> prediction : restoredModel.predict(dataset)) {
        clusterRows.add(Collections.singletonMap("ClusterID", prediction.getOutput().getID()));
    }
    return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(clusterRows)).build();
}
Use of org.tribuo.clustering.ClusterID in project ml-commons by opensearch-project.
Class KMeans, method train.
/**
 * Trains a KMeans model on the given data frame and wraps it in a {@code Model}.
 * The returned model is named after {@code FunctionName.KMEANS}, has version 1 and
 * carries the serialized tribuo KMeans model as its content.
 *
 * @param dataFrame training data
 * @return model holding the serialized KMeans model
 */
@Override
public Model train(DataFrame dataFrame) {
    MutableDataset<ClusterID> dataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans training data from opensearch", TribuoOutputType.CLUSTERID);

    // Fall back to the defaults when a parameter was not supplied.
    Integer centroids = parameters.getCentroids();
    if (centroids == null) {
        centroids = DEFAULT_CENTROIDS;
    }
    Integer iterations = parameters.getIterations();
    if (iterations == null) {
        iterations = DEFAULT_ITERATIONS;
    }

    KMeansModel trainedModel = new KMeansTrainer(centroids, iterations, distance, numThreads, seed).train(dataset);

    Model result = new Model();
    result.setName(FunctionName.KMEANS.name());
    result.setVersion(1);
    result.setContent(ModelSerDeSer.serialize(trainedModel));
    return result;
}
Use of org.tribuo.clustering.ClusterID in project ml-commons by opensearch-project.
Class TribuoUtilTest, method generateDataset.
/**
 * Checks that the generated dataset has one example per raw data row, and that each
 * example's features are named f1, f2, ... with values matching {@code row + n / 10.0}
 * (tolerance 0.01), where n is the 1-based feature position.
 */
@SuppressWarnings("unchecked")
@Test
public void generateDataset() {
    MutableDataset<ClusterID> generated = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "test", TribuoOutputType.CLUSTERID);
    List<Example<ClusterID>> examples = generated.getData();
    Assert.assertEquals(rawData.length, examples.size());

    for (int row = 0; row < rawData.length; ++row) {
        // The cast also verifies that every generated example is an ArrayExample.
        ArrayExample<ClusterID> example = (ArrayExample<ClusterID>) examples.get(row);
        int featureNumber = 1;
        for (Feature feature : example) {
            Assert.assertEquals("f" + featureNumber, feature.getName());
            Assert.assertEquals(row + featureNumber / 10.0, feature.getValue(), 0.01);
            ++featureNumber;
        }
    }
}
Aggregations