Search in sources :

Example 6 with DataFrame

use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.

the class MLInputTests method testParseKmeansInputDataFrame.

/**
 * Parses a KMEANS request whose {@code input_data} holds two columns
 * ({@code total_sum} as DOUBLE, {@code is_error} as BOOLEAN) and two rows, then
 * checks the column metadata and the first row of the resulting data frame.
 */
public void testParseKmeansInputDataFrame() throws IOException {
    // The request body is split across literals only for readability; the content is one JSON document.
    String requestBody = "{\"input_data\":{\"column_metas\":[{\"name\":\"total_sum\",\"column_type\":\"DOUBLE\"},{\"name\":\"is_error\","
        + "\"column_type\":\"BOOLEAN\"}],\"rows\":[{\"values\":[{\"column_type\":\"DOUBLE\",\"value\":15},"
        + "{\"column_type\":\"BOOLEAN\",\"value\":false}]},{\"values\":[{\"column_type\":\"DOUBLE\",\"value\":100},"
        + "{\"column_type\":\"BOOLEAN\",\"value\":true}]}]}}";
    XContentParser contentParser = parser(requestBody);
    MLInput parsedInput = MLInput.parse(contentParser, FunctionName.KMEANS.name());
    DataFrame frame = ((DataFrameInputDataset) parsedInput.getInputDataset()).getDataFrame();

    // Column metadata: count, then type and name of each column.
    assertEquals(2, frame.columnMetas().length);
    assertEquals(ColumnType.DOUBLE, frame.columnMetas()[0].getColumnType());
    assertEquals(ColumnType.BOOLEAN, frame.columnMetas()[1].getColumnType());
    assertEquals("total_sum", frame.columnMetas()[0].getName());
    assertEquals("is_error", frame.columnMetas()[1].getName());

    // First row: declared type of each cell, then the parsed values.
    assertEquals(ColumnType.DOUBLE, frame.getRow(0).getValue(0).columnType());
    assertEquals(ColumnType.BOOLEAN, frame.getRow(0).getValue(1).columnType());
    assertEquals(15.0, frame.getRow(0).getValue(0).getValue());
    assertEquals(false, frame.getRow(0).getValue(1).getValue());
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) XContentParser(org.opensearch.common.xcontent.XContentParser)

Example 7 with DataFrame

use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.

the class MLInputDatasetHandlerTests method testSearchQueryInputDatasetWithHits.

/**
 * Verifies that {@code parseSearchQueryInput} reports exactly one {@link DataFrame}
 * to the listener when the search returns a single hit.
 *
 * <p>The client's {@code search} call is stubbed to immediately answer with a mocked
 * {@link SearchResponse} containing one hit whose source is {@code {"taskId":"111"}}.
 */
@SuppressWarnings("unchecked")
public void testSearchQueryInputDatasetWithHits() {
    searchResponse = mock(SearchResponse.class);
    BytesReference bytesArray = new BytesArray("{\"taskId\":\"111\"}");
    SearchHit hit = new SearchHit(1);
    hit.sourceRef(bytesArray);
    SearchHits hits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1f);
    when(searchResponse.getHits()).thenReturn(hits);
    doAnswer(invocation -> {
        // Named distinctly from the outer `listener` passed to parseSearchQueryInput below,
        // so the two callbacks are not confused.
        ActionListener<SearchResponse> searchListener = (ActionListener<SearchResponse>) invocation.getArguments()[1];
        searchListener.onResponse(searchResponse);
        return null;
    }).when(client).search(any(), any());
    SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset.builder().indices(Collections.singletonList("index1")).searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery())).build();
    mlInputDatasetHandler.parseSearchQueryInput(searchQueryInputDataset, listener);
    ArgumentCaptor<DataFrame> captor = ArgumentCaptor.forClass(DataFrame.class);
    verify(listener, times(1)).onResponse(captor.capture());
    // JUnit's contract is assertEquals(expected, actual); keeping that order makes a
    // failure message report the right "expected" value.
    Assert.assertEquals(1, captor.getAllValues().size());
}
Also used : BytesReference(org.opensearch.common.bytes.BytesReference) TotalHits(org.apache.lucene.search.TotalHits) BytesArray(org.opensearch.common.bytes.BytesArray) SearchQueryInputDataset(org.opensearch.ml.common.dataset.SearchQueryInputDataset) SearchHit(org.opensearch.search.SearchHit) ActionListener(org.opensearch.action.ActionListener) SearchHits(org.opensearch.search.SearchHits) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) SearchResponse(org.opensearch.action.search.SearchResponse) SearchSourceBuilder(org.opensearch.search.builder.SearchSourceBuilder)

Example 8 with DataFrame

use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.

the class TribuoUtil method generateDatasetWithTarget.

/**
 * Builds a tribuo dataset from a data frame, using one named column as the
 * output (target) and every other column as a feature.
 *
 * @param dataFrame features data, including the target column
 * @param outputFactory the tribuo output factory
 * @param desc description for tribuo provenance
 * @param outputType the tribuo output type; only {@code REGRESSOR} is handled here
 * @param target name of the column to use as the target
 * @return tribuo dataset with one example per data frame row
 * @throws IllegalArgumentException if {@code target} is empty, if no column is named
 *         {@code target}, or if a row is processed with an unhandled {@code outputType}
 */
public static <T extends Output<T>> MutableDataset<T> generateDatasetWithTarget(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType, String target) {
    if (StringUtils.isEmpty(target)) {
        throw new IllegalArgumentException("Empty target when generating dataset from data frame.");
    }
    List<Example<T>> examples = new ArrayList<>();
    Tuple<String[], double[][]> transformed = transformDataFrame(dataFrame);
    String[] columnNames = transformed.v1();
    double[][] columnValues = transformed.v2();

    // Position of the first column whose name equals the target, or -1 when absent.
    final int targetColumn = IntStream.range(0, columnNames.length)
        .filter(col -> columnNames[col].equals(target))
        .findFirst()
        .orElse(-1);
    if (targetColumn == -1) {
        throw new IllegalArgumentException("No matched target when generating dataset from data frame.");
    }

    // Feature names are all column names except the target, in their original order.
    String[] featureNames = new String[columnNames.length - 1];
    int nameCursor = 0;
    for (int col = 0; col < columnNames.length; ++col) {
        if (col != targetColumn) {
            featureNames[nameCursor++] = columnNames[col];
        }
    }

    for (int rowIndex = 0; rowIndex < dataFrame.size(); ++rowIndex) {
        ArrayExample<T> rowExample;
        // The output type is checked per row, so an empty data frame never reaches the default branch.
        switch (outputType) {
            case REGRESSOR: {
                double[] row = columnValues[rowIndex];
                double targetValue = row[targetColumn];
                // Copy every value but the target into the feature vector.
                double[] featureValues = new double[row.length - 1];
                int valueCursor = 0;
                for (int col = 0; col < row.length; ++col) {
                    if (col != targetColumn) {
                        featureValues[valueCursor++] = row[col];
                    }
                }
                rowExample = new ArrayExample<>((T) new Regressor(target, targetValue), featureNames, featureValues);
                break;
            }
            default:
                throw new IllegalArgumentException("unknown type:" + outputType);
        }
        examples.add(rowExample);
    }

    SimpleDataSourceProvenance sourceProvenance = new SimpleDataSourceProvenance(desc, outputFactory);
    ListDataSource<T> dataSource = new ListDataSource<>(examples, outputFactory, sourceProvenance);
    return new MutableDataset<>(dataSource);
}
Also used : IntStream(java.util.stream.IntStream) Example(org.tribuo.Example) Arrays(java.util.Arrays) Row(org.opensearch.ml.common.dataframe.Row) Iterator(java.util.Iterator) ColumnValue(org.opensearch.ml.common.dataframe.ColumnValue) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) ClusterID(org.tribuo.clustering.ClusterID) StringUtils(org.apache.commons.lang3.StringUtils) OutputFactory(org.tribuo.OutputFactory) Event(org.tribuo.anomaly.Event) Tuple(org.opensearch.common.collect.Tuple) ArrayList(java.util.ArrayList) UtilityClass(lombok.experimental.UtilityClass) ArrayExample(org.tribuo.impl.ArrayExample) Regressor(org.tribuo.regression.Regressor) List(java.util.List) TribuoOutputType(org.opensearch.ml.engine.contants.TribuoOutputType) Output(org.tribuo.Output) ListDataSource(org.tribuo.datasource.ListDataSource) SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) StreamSupport(java.util.stream.StreamSupport) ColumnMeta(org.opensearch.ml.common.dataframe.ColumnMeta) MutableDataset(org.tribuo.MutableDataset) SimpleDataSourceProvenance(org.tribuo.provenance.SimpleDataSourceProvenance) ArrayList(java.util.ArrayList) Example(org.tribuo.Example) ArrayExample(org.tribuo.impl.ArrayExample) Regressor(org.tribuo.regression.Regressor) MutableDataset(org.tribuo.MutableDataset)

Example 9 with DataFrame

use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.

the class KMeansTest method predict.

/**
 * Trains a model on the training frame, predicts on the prediction frame, and
 * checks that there is one prediction per input row and that the first value of
 * every prediction row is either 0 or 1.
 */
@Test
public void predict() {
    Model trainedModel = kMeans.train(trainDataFrame);
    MLPredictionOutput predictionOutput = (MLPredictionOutput) kMeans.predict(predictionDataFrame, trainedModel);
    DataFrame result = predictionOutput.getPredictionResult();
    Assert.assertEquals(predictionSize, result.size());
    result.forEach(resultRow -> {
        int firstValue = resultRow.getValue(0).intValue();
        Assert.assertTrue(firstValue == 0 || firstValue == 1);
    });
}
Also used : Model(org.opensearch.ml.common.parameter.Model) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) KMeansHelper.constructKMeansDataFrame(org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame) Test(org.junit.Test)

Example 10 with DataFrame

use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.

the class FixedInTimeRandomCutForestTest method constructRCFDataFrame.

/**
 * Builds a two-column data frame ({@code timestamp} as LONG, {@code value} as INTEGER)
 * with {@code dataSize} rows spaced one minute apart.
 *
 * <p>Values are random integers in {@code [1, 10)}. When {@code predict} is true, every
 * 100th row (starting at row 0) instead gets a random integer in {@code [100, 1000)}
 * — presumably to act as outliers for the prediction test; confirm against the callers.
 *
 * @param predict whether to inject the larger values on every 100th row
 * @return the populated data frame
 */
private DataFrame constructRCFDataFrame(boolean predict) {
    ColumnMeta[] columnMetas = new ColumnMeta[] { new ColumnMeta("timestamp", ColumnType.LONG), new ColumnMeta("value", ColumnType.INTEGER) };
    DataFrame dataFrame = new DefaultDataFrame(columnMetas);
    long startTime = 1643677200000L;
    // 1 minute interval, kept as a long so the per-row offset is computed in 64-bit
    // arithmetic; `i * 1000 * 60` in int arithmetic overflows once i exceeds 35791.
    final long intervalMillis = 60L * 1000L;
    for (int i = 0; i < dataSize; i++) {
        long time = startTime + i * intervalMillis;
        if (predict && i % 100 == 0) {
            dataFrame.appendRow(new Object[] { time, ThreadLocalRandom.current().nextInt(100, 1000) });
        } else {
            dataFrame.appendRow(new Object[] { time, ThreadLocalRandom.current().nextInt(1, 10) });
        }
    }
    return dataFrame;
}
Also used : ColumnMeta(org.opensearch.ml.common.dataframe.ColumnMeta) DefaultDataFrame(org.opensearch.ml.common.dataframe.DefaultDataFrame) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) DefaultDataFrame(org.opensearch.ml.common.dataframe.DefaultDataFrame)

Aggregations

DataFrame (org.opensearch.ml.common.dataframe.DataFrame)34 ColumnMeta (org.opensearch.ml.common.dataframe.ColumnMeta)10 DefaultDataFrame (org.opensearch.ml.common.dataframe.DefaultDataFrame)10 MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput)10 MLInput (org.opensearch.ml.common.parameter.MLInput)9 ArrayList (java.util.ArrayList)8 Test (org.junit.Test)8 Model (org.opensearch.ml.common.parameter.Model)8 Row (org.opensearch.ml.common.dataframe.Row)7 DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset)7 MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset)7 KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame)7 HashMap (java.util.HashMap)6 ColumnValue (org.opensearch.ml.common.dataframe.ColumnValue)6 LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame)5 LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame)5 List (java.util.List)4 Before (org.junit.Before)4 Input (org.opensearch.ml.common.parameter.Input)4 LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput)4