Use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project: class MLInputTests, method testParseKmeansInputDataFrame.
public void testParseKmeansInputDataFrame() throws IOException {
    String query = "{\"input_data\":{\"column_metas\":[{\"name\":\"total_sum\",\"column_type\":\"DOUBLE\"},{\"name\":\"is_error\","
        + "\"column_type\":\"BOOLEAN\"}],\"rows\":[{\"values\":[{\"column_type\":\"DOUBLE\",\"value\":15},"
        + "{\"column_type\":\"BOOLEAN\",\"value\":false}]},{\"values\":[{\"column_type\":\"DOUBLE\",\"value\":100},"
        + "{\"column_type\":\"BOOLEAN\",\"value\":true}]}]}}";
    XContentParser parser = parser(query);
    MLInput mlInput = MLInput.parse(parser, FunctionName.KMEANS.name());
    DataFrameInputDataset inputDataset = (DataFrameInputDataset) mlInput.getInputDataset();
    DataFrame dataFrame = inputDataset.getDataFrame();
    assertEquals(2, dataFrame.columnMetas().length);
    assertEquals(ColumnType.DOUBLE, dataFrame.columnMetas()[0].getColumnType());
    assertEquals(ColumnType.BOOLEAN, dataFrame.columnMetas()[1].getColumnType());
    assertEquals("total_sum", dataFrame.columnMetas()[0].getName());
    assertEquals("is_error", dataFrame.columnMetas()[1].getName());
    assertEquals(ColumnType.DOUBLE, dataFrame.getRow(0).getValue(0).columnType());
    assertEquals(ColumnType.BOOLEAN, dataFrame.getRow(0).getValue(1).columnType());
    assertEquals(15.0, dataFrame.getRow(0).getValue(0).getValue());
    assertEquals(false, dataFrame.getRow(0).getValue(1).getValue());
}
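For comparison, the same two-column frame can be assembled programmatically instead of being parsed from JSON. This is a minimal sketch that reuses only the ColumnMeta, DefaultDataFrame, and appendRow calls shown in the constructRCFDataFrame example further down this page; reading values back mirrors the assertions above.

// Sketch: build the two-column frame directly (values are the same as in the JSON above).
ColumnMeta[] columnMetas = new ColumnMeta[] {
    new ColumnMeta("total_sum", ColumnType.DOUBLE),
    new ColumnMeta("is_error", ColumnType.BOOLEAN)
};
DataFrame dataFrame = new DefaultDataFrame(columnMetas);
dataFrame.appendRow(new Object[] { 15.0, false });
dataFrame.appendRow(new Object[] { 100.0, true });
// getValue() returns the boxed value, as the assertEquals calls above rely on.
double totalSum = (double) dataFrame.getRow(0).getValue(0).getValue();
boolean isError = (boolean) dataFrame.getRow(0).getValue(1).getValue();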
Use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project: class MLInputDatasetHandlerTests, method testSearchQueryInputDatasetWithHits.
@SuppressWarnings("unchecked")
public void testSearchQueryInputDatasetWithHits() {
    searchResponse = mock(SearchResponse.class);
    BytesReference bytesArray = new BytesArray("{\"taskId\":\"111\"}");
    SearchHit hit = new SearchHit(1);
    hit.sourceRef(bytesArray);
    SearchHits hits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1f);
    when(searchResponse.getHits()).thenReturn(hits);
    doAnswer(invocation -> {
        ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocation.getArguments()[1];
        listener.onResponse(searchResponse);
        return null;
    }).when(client).search(any(), any());
    SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset
        .builder()
        .indices(Collections.singletonList("index1"))
        .searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()))
        .build();
    mlInputDatasetHandler.parseSearchQueryInput(searchQueryInputDataset, listener);
    ArgumentCaptor<DataFrame> captor = ArgumentCaptor.forClass(DataFrame.class);
    verify(listener, times(1)).onResponse(captor.capture());
    Assert.assertEquals(1, captor.getAllValues().size());
}
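Conceptually, parseSearchQueryInput runs the search and turns each hit's source into one DataFrame row before calling the listener. The helper below is a hypothetical illustration of that conversion, not the handler's actual code; it uses only the DefaultDataFrame, ColumnMeta, and appendRow API shown elsewhere on this page and assumes the column types are already known.

// Hypothetical helper: one flat source map per hit becomes one row of the frame.
private DataFrame hitsToDataFrame(List<Map<String, Object>> hitSources, ColumnMeta[] columnMetas) {
    DataFrame dataFrame = new DefaultDataFrame(columnMetas);
    for (Map<String, Object> source : hitSources) {
        Object[] row = new Object[columnMetas.length];
        for (int i = 0; i < columnMetas.length; i++) {
            row[i] = source.get(columnMetas[i].getName());
        }
        dataFrame.appendRow(row);
    }
    return dataFrame;
}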
Use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project: class TribuoUtil, method generateDatasetWithTarget.
/**
 * Generate a Tribuo dataset from a data frame that includes the target column.
 * @param dataFrame data frame containing both feature and target columns
 * @param outputFactory the Tribuo output factory
 * @param desc description for Tribuo provenance
 * @param outputType the Tribuo output type
 * @param target name of the target column
 * @return Tribuo dataset
 */
public static <T extends Output<T>> MutableDataset<T> generateDatasetWithTarget(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType, String target) {
    if (StringUtils.isEmpty(target)) {
        throw new IllegalArgumentException("Empty target when generating dataset from data frame.");
    }
    List<Example<T>> dataset = new ArrayList<>();
    Tuple<String[], double[][]> featureNamesValues = transformDataFrame(dataFrame);
    int targetIndex = -1;
    for (int i = 0; i < featureNamesValues.v1().length; ++i) {
        if (featureNamesValues.v1()[i].equals(target)) {
            targetIndex = i;
            break;
        }
    }
    if (targetIndex == -1) {
        throw new IllegalArgumentException("No matched target when generating dataset from data frame.");
    }
    ArrayExample<T> example;
    final int finalTargetIndex = targetIndex;
    String[] featureNames = IntStream
        .range(0, featureNamesValues.v1().length)
        .filter(e -> e != finalTargetIndex)
        .mapToObj(e -> featureNamesValues.v1()[e])
        .toArray(String[]::new);
    for (int i = 0; i < dataFrame.size(); ++i) {
        switch (outputType) {
            case REGRESSOR:
                final int finalI = i;
                double targetValue = featureNamesValues.v2()[finalI][finalTargetIndex];
                double[] featureValues = IntStream
                    .range(0, featureNamesValues.v2()[i].length)
                    .filter(e -> e != finalTargetIndex)
                    .mapToDouble(e -> featureNamesValues.v2()[finalI][e])
                    .toArray();
                example = new ArrayExample<>((T) new Regressor(target, targetValue), featureNames, featureValues);
                break;
            default:
                throw new IllegalArgumentException("unknown type:" + outputType);
        }
        dataset.add(example);
    }
    SimpleDataSourceProvenance provenance = new SimpleDataSourceProvenance(desc, outputFactory);
    return new MutableDataset<>(new ListDataSource<>(dataset, outputFactory, provenance));
}
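A sketch of how this method might be called for a regression problem. The column names and values here are made up for illustration; RegressionFactory and Regressor come from Tribuo's regression package, and, as the code above shows, the target column is dropped from the feature set and used to build the output.

// Illustrative frame: two feature columns plus a "price" target column.
ColumnMeta[] columnMetas = new ColumnMeta[] {
    new ColumnMeta("sqft", ColumnType.DOUBLE),
    new ColumnMeta("age", ColumnType.DOUBLE),
    new ColumnMeta("price", ColumnType.DOUBLE)
};
DataFrame dataFrame = new DefaultDataFrame(columnMetas);
dataFrame.appendRow(new Object[] { 120.0, 5.0, 300.0 });
dataFrame.appendRow(new Object[] { 80.0, 30.0, 150.0 });
// Each example gets two features (sqft, age) and a Regressor output built from "price".
MutableDataset<Regressor> dataset = TribuoUtil.generateDatasetWithTarget(
    dataFrame,
    new RegressionFactory(),
    "house price sample",
    TribuoOutputType.REGRESSOR,
    "price"
);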
Use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project: class KMeansTest, method predict.
@Test
public void predict() {
    Model model = kMeans.train(trainDataFrame);
    MLPredictionOutput output = (MLPredictionOutput) kMeans.predict(predictionDataFrame, model);
    DataFrame predictions = output.getPredictionResult();
    Assert.assertEquals(predictionSize, predictions.size());
    predictions.forEach(row -> Assert.assertTrue(row.getValue(0).intValue() == 0 || row.getValue(0).intValue() == 1));
}
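If a test needed more than the 0-or-1 check, the same iteration pattern can tally how many rows fell into each cluster. A minimal sketch, reusing only the forEach and getValue(0).intValue() calls from the assertion above:

// Count rows per predicted cluster id (column 0 of the prediction frame).
Map<Integer, Integer> clusterCounts = new HashMap<>();
predictions.forEach(row -> clusterCounts.merge(row.getValue(0).intValue(), 1, Integer::sum));
// With k = 2 the map ends up with at most two keys: 0 and 1.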
Use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project: class FixedInTimeRandomCutForestTest, method constructRCFDataFrame.
private DataFrame constructRCFDataFrame(boolean predict) {
    ColumnMeta[] columnMetas = new ColumnMeta[] { new ColumnMeta("timestamp", ColumnType.LONG), new ColumnMeta("value", ColumnType.INTEGER) };
    DataFrame dataFrame = new DefaultDataFrame(columnMetas);
    long startTime = 1643677200000L;
    for (int i = 0; i < dataSize; i++) {
        // 1 minute interval
        long time = startTime + i * 1000 * 60;
        if (predict && i % 100 == 0) {
            dataFrame.appendRow(new Object[] { time, ThreadLocalRandom.current().nextInt(100, 1000) });
        } else {
            dataFrame.appendRow(new Object[] { time, ThreadLocalRandom.current().nextInt(1, 10) });
        }
    }
    return dataFrame;
}
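Usage note: the predict flag decides whether the series contains injected spikes, so the test can build both frames from the same helper. A minimal sketch (only constructRCFDataFrame itself is taken from this test class):

DataFrame trainDataFrame = constructRCFDataFrame(false);  // steady values in [1, 10)
DataFrame predictDataFrame = constructRCFDataFrame(true); // a spike in [100, 1000) every 100th row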