Search in sources :

Example 16 with DataFrame

Use of org.opensearch.ml.common.dataframe.DataFrame in the ml-commons project by opensearch-project.

From class MLEngine, method validateMLInput.

/**
 * Checks that {@code input} is an {@link MLInput} whose data frame is present and non-empty.
 *
 * <p>{@code validateInput} is invoked first; any checks or exceptions it adds are defined
 * elsewhere.
 *
 * @param input the input to validate
 * @throws IllegalArgumentException if {@code input} is not an {@code MLInput}, or if its data
 *     frame is {@code null} or has no rows
 */
private static void validateMLInput(Input input) {
    validateInput(input);
    if (!(input instanceof MLInput)) {
        throw new IllegalArgumentException("Input should be MLInput");
    }
    DataFrame frame = ((MLInput) input).getDataFrame();
    // A missing frame and a zero-row frame are rejected with the same message.
    boolean frameIsEmpty = frame == null || frame.size() == 0;
    if (frameIsEmpty) {
        throw new IllegalArgumentException("Input data frame should not be null or empty");
    }
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput) DataFrame(org.opensearch.ml.common.dataframe.DataFrame)

Example 17 with DataFrame

Use of org.opensearch.ml.common.dataframe.DataFrame in the ml-commons project by opensearch-project.

From class MLPredictionTaskRequestTest, method writeTo_Success.

/**
 * Round-trips a prediction request through {@code writeTo} and the stream constructor, then
 * checks that the algorithm, parameters, data frame contents and model id match the fixture.
 */
@Test
public void writeTo_Success() throws IOException {
    // Serialize a request built from the shared mlInput fixture and read it back.
    MLPredictionTaskRequest original = MLPredictionTaskRequest.builder().mlInput(mlInput).build();
    BytesStreamOutput out = new BytesStreamOutput();
    original.writeTo(out);
    MLPredictionTaskRequest parsed = new MLPredictionTaskRequest(out.bytes().streamInput());

    // Algorithm and KMeans parameters survive the round trip.
    MLInput parsedInput = parsed.getMlInput();
    assertEquals(FunctionName.KMEANS, parsedInput.getAlgorithm());
    KMeansParams kMeansParams = (KMeansParams) parsedInput.getParameters();
    assertEquals(1, kMeansParams.getCentroids().intValue());

    // The dataset is still a data frame: one DOUBLE column "key1" and a single row holding 2.00.
    MLInputDataset dataset = parsedInput.getInputDataset();
    assertEquals(MLInputDataType.DATA_FRAME, dataset.getInputDataType());
    DataFrame frame = ((DataFrameInputDataset) dataset).getDataFrame();
    assertEquals(1, frame.size());
    assertEquals(1, frame.columnMetas().length);
    assertEquals("key1", frame.columnMetas()[0].getName());
    assertEquals(ColumnType.DOUBLE, frame.columnMetas()[0].getColumnType());
    assertEquals(1, frame.getRow(0).size());
    assertEquals(2.00, frame.getRow(0).getValue(0).getValue());

    // No model id was set on the fixture request.
    assertNull(parsed.getModelId());
}
Also used : KMeansParams(org.opensearch.ml.common.parameter.KMeansParams) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) MLInputDataset(org.opensearch.ml.common.dataset.MLInputDataset) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) BytesStreamOutput(org.opensearch.common.io.stream.BytesStreamOutput) Test(org.junit.Test)

Example 18 with DataFrame

Use of org.opensearch.ml.common.dataframe.DataFrame in the ml-commons project by opensearch-project.

From class KMeansTest, method trainAndPredict.

/**
 * Runs KMeans train-and-predict (2 centroids, 10 iterations) once per distance type
 * (EUCLIDEAN, COSINE, L1).
 *
 * <p>Each run must produce one prediction per training row, and every predicted cluster label
 * must be 0 or 1. Previously the label check ran only for EUCLIDEAN.
 */
@Test
public void trainAndPredict() {
    KMeansParams.DistanceType[] distanceTypes = {
        KMeansParams.DistanceType.EUCLIDEAN,
        KMeansParams.DistanceType.COSINE,
        KMeansParams.DistanceType.L1
    };
    for (KMeansParams.DistanceType distanceType : distanceTypes) {
        KMeansParams parameters = KMeansParams.builder()
            .distanceType(distanceType)
            .iterations(10)
            .centroids(2)
            .build();
        trainPredictAndVerify(parameters);
    }
}

/** Trains and predicts on trainDataFrame with the given parameters and checks the predictions. */
private void trainPredictAndVerify(KMeansParams parameters) {
    KMeans kMeans = new KMeans(parameters);
    MLPredictionOutput output = (MLPredictionOutput) kMeans.trainAndPredict(trainDataFrame);
    DataFrame predictions = output.getPredictionResult();
    Assert.assertEquals(trainSize, predictions.size());
    // With two centroids, every row must be assigned to cluster 0 or cluster 1.
    predictions.forEach(row -> {
        int label = row.getValue(0).intValue();
        Assert.assertTrue("unexpected cluster label " + label, label == 0 || label == 1);
    });
}
Also used : KMeansParams(org.opensearch.ml.common.parameter.KMeansParams) MLPredictionOutput(org.opensearch.ml.common.parameter.MLPredictionOutput) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) KMeansHelper.constructKMeansDataFrame(org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame) Test(org.junit.Test)

Example 19 with DataFrame

Use of org.opensearch.ml.common.dataframe.DataFrame in the ml-commons project by opensearch-project.

From class MLInputTest, method setUp.

/**
 * Builds the shared {@code input} fixture.
 *
 * <p>The fixture is an MLInput for {@code algorithm} with linear-regression parameters
 * (learning rate 0.1). Its data frame has one DOUBLE column named "test" with the values 1.0,
 * 2.0 and 3.0, one per row.
 */
@Before
public void setUp() throws Exception {
    final ColumnMeta[] columnMetas = { new ColumnMeta("test", ColumnType.DOUBLE) };
    double[] columnValues = { 1.0, 2.0, 3.0 };
    List<Row> rows = new ArrayList<>();
    for (double value : columnValues) {
        rows.add(new Row(new ColumnValue[] { new DoubleValue(value) }));
    }
    DataFrame dataFrame = new DefaultDataFrame(columnMetas, rows);

    LinearRegressionParams params = LinearRegressionParams.builder().learningRate(0.1).build();
    DataFrameInputDataset dataset = DataFrameInputDataset.builder().dataFrame(dataFrame).build();
    input = MLInput.builder()
        .algorithm(algorithm)
        .parameters(params)
        .inputDataset(dataset)
        .build();
}
Also used : ColumnMeta(org.opensearch.ml.common.dataframe.ColumnMeta) DoubleValue(org.opensearch.ml.common.dataframe.DoubleValue) ArrayList(java.util.ArrayList) ColumnValue(org.opensearch.ml.common.dataframe.ColumnValue) Row(org.opensearch.ml.common.dataframe.Row) DataFrame(org.opensearch.ml.common.dataframe.DataFrame) DefaultDataFrame(org.opensearch.ml.common.dataframe.DefaultDataFrame) Before(org.junit.Before)

Example 20 with DataFrame

Use of org.opensearch.ml.common.dataframe.DataFrame in the ml-commons project by opensearch-project.

From class MLTrainingTaskRunner, method startTrainingTask.

/**
 * Starts a training task.
 *
 * <p>The task's stat counters are incremented and the task is added to {@code mlTaskManager}.
 * Training is then dispatched in one of two ways:
 * <ul>
 *   <li>If the input dataset is a search query, it is first resolved into a DataFrame through
 *       {@code mlInputDatasetHandler}. The result is handed back on the task thread pool and
 *       training runs there.</li>
 *   <li>Otherwise, training is submitted directly to the task thread pool.</li>
 * </ul>
 * Any exception thrown while dispatching is logged and reported to the wrapped listener.
 *
 * @param mlTask ML task
 * @param mlInput ML input
 * @param listener Action listener, wrapped via {@code wrappedCleanupListener} with the task id
 */
public void startTrainingTask(MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> listener) {
    String taskId = mlTask.getTaskId();
    ActionListener<MLTaskResponse> internalListener = wrappedCleanupListener(listener, taskId);

    // Bookkeeping: bump stat counters and register the task before any work is dispatched.
    mlStats.getStat(ML_EXECUTING_TASK_COUNT).increment();
    mlStats.getStat(ML_TOTAL_REQUEST_COUNT).increment();
    mlStats.createCounterStatIfAbsent(requestCountStat(mlTask.getFunctionName(), ActionName.TRAIN)).increment();
    mlTaskManager.add(mlTask);

    try {
        boolean needsSearch = mlInput.getInputDataset().getInputDataType().equals(MLInputDataType.SEARCH_QUERY);
        if (!needsSearch) {
            threadPool.executor(TASK_THREAD_POOL).execute(() -> train(mlTask, mlInput, internalListener));
            return;
        }

        // Resolve the search query into a DataFrame, then train on a copy of the input that uses it.
        ActionListener<DataFrame> onDataFrame = ActionListener.wrap(
            dataFrame -> train(mlTask, mlInput.toBuilder().inputDataset(new DataFrameInputDataset(dataFrame)).build(), internalListener),
            failure -> {
                log.error("Failed to generate DataFrame from search query", failure);
                internalListener.onFailure(failure);
            }
        );
        ActionListener<DataFrame> onTaskThread = new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, onDataFrame, false);
        mlInputDatasetHandler.parseSearchQueryInput(mlInput.getInputDataset(), onTaskThread);
    } catch (Exception e) {
        log.error("Failed to train " + mlInput.getAlgorithm(), e);
        internalListener.onFailure(e);
    }
}
Also used : MLTaskResponse(org.opensearch.ml.common.transport.MLTaskResponse) DataFrameInputDataset(org.opensearch.ml.common.dataset.DataFrameInputDataset) DataFrame(org.opensearch.ml.common.dataframe.DataFrame)

Aggregations

DataFrame (org.opensearch.ml.common.dataframe.DataFrame)34 ColumnMeta (org.opensearch.ml.common.dataframe.ColumnMeta)10 DefaultDataFrame (org.opensearch.ml.common.dataframe.DefaultDataFrame)10 MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput)10 MLInput (org.opensearch.ml.common.parameter.MLInput)9 ArrayList (java.util.ArrayList)8 Test (org.junit.Test)8 Model (org.opensearch.ml.common.parameter.Model)8 Row (org.opensearch.ml.common.dataframe.Row)7 DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset)7 MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset)7 KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame)7 HashMap (java.util.HashMap)6 ColumnValue (org.opensearch.ml.common.dataframe.ColumnValue)6 LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame)5 LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame)5 List (java.util.List)4 Before (org.junit.Before)4 Input (org.opensearch.ml.common.parameter.Input)4 LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput)4