Example 21 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in project ml-commons by opensearch-project.

In class MLEngineTest, method trainKMeansModel.

private Model trainKMeansModel() {
    KMeansParams parameters = KMeansParams.builder().centroids(2).iterations(10).distanceType(KMeansParams.DistanceType.EUCLIDEAN).build();
    DataFrame trainDataFrame = constructKMeansDataFrame(100);
    MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(trainDataFrame).build();
    Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).parameters(parameters).inputDataset(inputDataset).build();
    return MLEngine.train(mlInput);
}
Also used :
    MLInput (org.opensearch.ml.common.parameter.MLInput)
    Input (org.opensearch.ml.common.parameter.Input)
    LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput)
    KMeansParams (org.opensearch.ml.common.parameter.KMeansParams)
    MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset)
    LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame)
    KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame)
    LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame)
    DataFrame (org.opensearch.ml.common.dataframe.DataFrame)

Example 22 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in project ml-commons by opensearch-project.

In class RestMLPredictionAction, method getRequest.

/**
 * Creates a MLPredictionTaskRequest from a RestRequest
 *
 * @param request RestRequest
 * @return MLPredictionTaskRequest
 */
@VisibleForTesting
MLPredictionTaskRequest getRequest(RestRequest request) throws IOException {
    String algorithm = getAlgorithm(request);
    String modelId = getParameterId(request, PARAMETER_MODEL_ID);
    XContentParser parser = request.contentParser();
    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
    MLInput mlInput = MLInput.parse(parser, algorithm);
    return new MLPredictionTaskRequest(modelId, mlInput);
}
Also used :
    MLInput (org.opensearch.ml.common.parameter.MLInput)
    MLPredictionTaskRequest (org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest)
    XContentParser (org.opensearch.common.xcontent.XContentParser)
    VisibleForTesting (com.google.common.annotations.VisibleForTesting)

Example 23 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in project ml-commons by opensearch-project.

In class MLEngine, method train.

public static Model train(Input input) {
    validateMLInput(input);
    MLInput mlInput = (MLInput) input;
    Trainable trainable = MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
    if (trainable == null) {
        throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
    }
    return trainable.train(mlInput.getDataFrame());
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput)

Example 24 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in project ml-commons by opensearch-project.

In class MLEngine, method predict.

public static MLOutput predict(Input input, Model model) {
    validateMLInput(input);
    MLInput mlInput = (MLInput) input;
    Predictable predictable = MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
    if (predictable == null) {
        throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
    }
    return predictable.predict(mlInput.getDataFrame(), model);
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput)

Example 25 with MLInput

Use of org.opensearch.ml.common.parameter.MLInput in project ml-commons by opensearch-project.

In class MLEngine, method trainAndPredict.

public static MLOutput trainAndPredict(Input input) {
    validateMLInput(input);
    MLInput mlInput = (MLInput) input;
    TrainAndPredictable trainAndPredictable = MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
    if (trainAndPredictable == null) {
        throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
    }
    return trainAndPredictable.trainAndPredict(mlInput.getDataFrame());
}
Also used : MLInput(org.opensearch.ml.common.parameter.MLInput)
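
The three MLEngine methods above (train, predict, trainAndPredict) follow the same shape: validate the Input, downcast it to MLInput, resolve an implementation from the algorithm name, and throw IllegalArgumentException when none is found. The sketch below illustrates only that shape. Every type in it (DispatchSketch, SketchInput, the TRAINERS map, the "MEAN" algorithm) is a hypothetical stand-in and not part of ml-commons; in particular, the map lookup replaces MLEngineClassLoader.initInstance, and the instanceof check is an assumption about what a validation step might do, not a description of validateMLInput.

```java
import java.util.Map;
import java.util.function.Function;

public class DispatchSketch {
    // Hypothetical stand-in for the Input interface.
    interface Input {}

    // Hypothetical stand-in for MLInput: an algorithm name plus raw data.
    static final class SketchInput implements Input {
        final String algorithm;
        final double[] data;

        SketchInput(String algorithm, double[] data) {
            this.algorithm = algorithm;
            this.data = data;
        }
    }

    // Toy "training" routine: returns the arithmetic mean of the data.
    static Double mean(double[] data) {
        double sum = 0;
        for (double d : data) {
            sum += d;
        }
        return sum / data.length;
    }

    // Hypothetical registry keyed by algorithm name (assumption: the real
    // code resolves implementations through MLEngineClassLoader instead).
    private static final Map<String, Function<double[], Double>> TRAINERS =
        Map.of("MEAN", DispatchSketch::mean);

    static double train(Input input) {
        // Validate, then downcast, mirroring the order used in the examples above.
        if (!(input instanceof SketchInput)) {
            throw new IllegalArgumentException("Input should be a SketchInput");
        }
        SketchInput sketchInput = (SketchInput) input;
        Function<double[], Double> trainer = TRAINERS.get(sketchInput.algorithm);
        if (trainer == null) {
            throw new IllegalArgumentException("Unsupported algorithm: " + sketchInput.algorithm);
        }
        return trainer.apply(sketchInput.data);
    }

    public static void main(String[] args) {
        System.out.println(train(new SketchInput("MEAN", new double[] {1.0, 2.0, 3.0})));
    }
}
```

Passing an algorithm name that is not registered, such as "KMEANS" in this toy registry, reaches the null check and raises the same "Unsupported algorithm" error seen in the listings.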

Aggregations

MLInput (org.opensearch.ml.common.parameter.MLInput): 46
Test (org.junit.Test): 18
MLInputDataset (org.opensearch.ml.common.dataset.MLInputDataset): 13
MLTaskResponse (org.opensearch.ml.common.transport.MLTaskResponse): 12
DataFrameInputDataset (org.opensearch.ml.common.dataset.DataFrameInputDataset): 11
DataFrame (org.opensearch.ml.common.dataframe.DataFrame): 10
Input (org.opensearch.ml.common.parameter.Input): 9
LocalSampleCalculatorInput (org.opensearch.ml.common.parameter.LocalSampleCalculatorInput): 9
MLPredictionOutput (org.opensearch.ml.common.parameter.MLPredictionOutput): 7
MLPredictionTaskRequest (org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest): 7
MLTrainingTaskRequest (org.opensearch.ml.common.transport.training.MLTrainingTaskRequest): 7
MLOutput (org.opensearch.ml.common.parameter.MLOutput): 6
XContentParser (org.opensearch.common.xcontent.XContentParser): 5
Response (org.opensearch.client.Response): 4
Model (org.opensearch.ml.common.parameter.Model): 4
KMeansHelper.constructKMeansDataFrame (org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame): 4
LinearRegressionHelper.constructLinearRegressionPredictionDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame): 4
LinearRegressionHelper.constructLinearRegressionTrainDataFrame (org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame): 4
VisibleForTesting (com.google.common.annotations.VisibleForTesting): 3
Instant (java.time.Instant): 3