use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.
the class MLEngine method validateMLInput.
/**
 * Validates that the given input is an {@code MLInput} whose data frame has at least one row.
 * The generic {@code validateInput} check runs first.
 *
 * @param input input to validate
 * @throws IllegalArgumentException if the input is not an {@code MLInput}, or if its data frame
 *         is null or has size 0
 */
private static void validateMLInput(Input input) {
    validateInput(input);
    if (input instanceof MLInput) {
        DataFrame frame = ((MLInput) input).getDataFrame();
        if (frame != null && frame.size() != 0) {
            // Non-empty data frame: nothing more to check.
            return;
        }
        throw new IllegalArgumentException("Input data frame should not be null or empty");
    }
    throw new IllegalArgumentException("Input should be MLInput");
}
use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.
the class MLPredictionTaskRequestTest method writeTo_Success.
/**
 * Writes a prediction request (built with only an ML input) to a stream, reads it back, and
 * verifies the algorithm, parameters, input data frame and absent model id on the restored copy.
 */
@Test
public void writeTo_Success() throws IOException {
    // Serialize the request.
    BytesStreamOutput out = new BytesStreamOutput();
    MLPredictionTaskRequest original = MLPredictionTaskRequest.builder().mlInput(mlInput).build();
    original.writeTo(out);

    // Deserialize it from the written bytes.
    MLPredictionTaskRequest restored = new MLPredictionTaskRequest(out.bytes().streamInput());

    // Algorithm and parameters.
    assertEquals(FunctionName.KMEANS, restored.getMlInput().getAlgorithm());
    KMeansParams kMeansParams = (KMeansParams) restored.getMlInput().getParameters();
    assertEquals(1, kMeansParams.getCentroids().intValue());

    // Input dataset and the data frame it wraps.
    MLInputDataset dataset = restored.getMlInput().getInputDataset();
    assertEquals(MLInputDataType.DATA_FRAME, dataset.getInputDataType());
    DataFrame frame = ((DataFrameInputDataset) dataset).getDataFrame();
    assertEquals(1, frame.size());
    assertEquals(1, frame.columnMetas().length);
    assertEquals("key1", frame.columnMetas()[0].getName());
    assertEquals(ColumnType.DOUBLE, frame.columnMetas()[0].getColumnType());
    assertEquals(1, frame.getRow(0).size());
    assertEquals(2.00, frame.getRow(0).getValue(0).getValue());

    // No model id was set on the builder.
    assertNull(restored.getModelId());
}
use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.
the class KMeansTest method trainAndPredict.
/**
 * Runs {@code trainAndPredict} on the training data frame with EUCLIDEAN, COSINE and L1
 * distance types (10 iterations, 2 centroids) and checks that one prediction is produced per
 * training row. For the EUCLIDEAN run, each predicted value must additionally be 0 or 1.
 */
@Test
public void trainAndPredict() {
    KMeansParams parameters = KMeansParams
        .builder()
        .distanceType(KMeansParams.DistanceType.EUCLIDEAN)
        .iterations(10)
        .centroids(2)
        .build();
    MLPredictionOutput output = (MLPredictionOutput) new KMeans(parameters).trainAndPredict(trainDataFrame);
    DataFrame predictions = output.getPredictionResult();
    Assert.assertEquals(trainSize, predictions.size());
    predictions.forEach(row -> {
        int cluster = row.getValue(0).intValue();
        Assert.assertTrue(cluster == 0 || cluster == 1);
    });

    // The remaining distance types only get the row-count check.
    KMeansParams.DistanceType[] otherDistances = { KMeansParams.DistanceType.COSINE, KMeansParams.DistanceType.L1 };
    for (KMeansParams.DistanceType distance : otherDistances) {
        parameters = parameters.toBuilder().distanceType(distance).build();
        output = (MLPredictionOutput) new KMeans(parameters).trainAndPredict(trainDataFrame);
        predictions = output.getPredictionResult();
        Assert.assertEquals(trainSize, predictions.size());
    }
}
use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.
the class MLInputTest method setUp.
/**
 * Builds the {@code input} fixture: an MLInput for the configured algorithm with linear
 * regression parameters (learning rate 0.1) and a single-column ("test", DOUBLE) data frame
 * holding the rows 1.0, 2.0 and 3.0.
 */
@Before
public void setUp() throws Exception {
    // One single-cell row per value, in order.
    List<Row> rows = new ArrayList<>();
    for (double cell : new double[] { 1.0, 2.0, 3.0 }) {
        rows.add(new Row(new ColumnValue[] { new DoubleValue(cell) }));
    }

    final ColumnMeta[] columnMetas = { new ColumnMeta("test", ColumnType.DOUBLE) };
    DataFrame dataFrame = new DefaultDataFrame(columnMetas, rows);

    input = MLInput
        .builder()
        .algorithm(algorithm)
        .parameters(LinearRegressionParams.builder().learningRate(0.1).build())
        .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build())
        .build();
}
use of org.opensearch.ml.common.dataframe.DataFrame in project ml-commons by opensearch-project.
the class MLTrainingTaskRunner method startTrainingTask.
/**
 * Starts a training task.
 * <p>
 * The executing-task, total-request and per-function TRAIN request counters are incremented and
 * the task is added to the task manager before any work is dispatched. When the input dataset's
 * type is SEARCH_QUERY, the dataset is passed to the input dataset handler to be parsed into a
 * DataFrame, and training is then invoked with a copy of the input that wraps that DataFrame.
 * For any other dataset type, training is submitted to the task thread pool with the input as
 * given. Any exception thrown while dispatching is logged and reported to the listener.
 *
 * @param mlTask ML task
 * @param mlInput ML input
 * @param listener Action listener
 */
public void startTrainingTask(MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> listener) {
    ActionListener<MLTaskResponse> internalListener = wrappedCleanupListener(listener, mlTask.getTaskId());

    // Update stats and register the task before dispatching any work.
    mlStats.getStat(ML_EXECUTING_TASK_COUNT).increment();
    mlStats.getStat(ML_TOTAL_REQUEST_COUNT).increment();
    mlStats.createCounterStatIfAbsent(requestCountStat(mlTask.getFunctionName(), ActionName.TRAIN)).increment();
    mlTaskManager.add(mlTask);

    try {
        if (!mlInput.getInputDataset().getInputDataType().equals(MLInputDataType.SEARCH_QUERY)) {
            // Input is not a search query: train on it directly in the task thread pool.
            threadPool.executor(TASK_THREAD_POOL).execute(() -> train(mlTask, mlInput, internalListener));
            return;
        }

        // Search query input: parse it into a DataFrame first, then train on the result.
        ActionListener<DataFrame> onDataFrame = ActionListener.wrap(
            frame -> train(mlTask, mlInput.toBuilder().inputDataset(new DataFrameInputDataset(frame)).build(), internalListener),
            failure -> {
                log.error("Failed to generate DataFrame from search query", failure);
                internalListener.onFailure(failure);
            }
        );
        mlInputDatasetHandler.parseSearchQueryInput(
            mlInput.getInputDataset(),
            new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, onDataFrame, false)
        );
    } catch (Exception e) {
        log.error("Failed to train " + mlInput.getAlgorithm(), e);
        internalListener.onFailure(e);
    }
}
Aggregations