Usage of org.opensearch.ml.common.dataset.MLInputDataset in the ml-commons project (opensearch-project).
Example: class MLEngineTest, method predictWithoutAlgoName.
/**
 * Verifies that predicting with an input that has no algorithm set is rejected with
 * IllegalArgumentException("algorithm can't be null").
 */
@Test
public void predictWithoutAlgoName() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("algorithm can't be null");

    // The input is built without calling .algorithm(...), so no algorithm name is present.
    DataFrame frame = constructKMeansDataFrame(10);
    MLInputDataset dataset = DataFrameInputDataset.builder().dataFrame(frame).build();
    Input inputWithoutAlgorithm = MLInput.builder().inputDataset(dataset).build();

    MLEngine.predict(inputWithoutAlgorithm, null);
}
Usage of org.opensearch.ml.common.dataset.MLInputDataset in the ml-commons project (opensearch-project).
Example: class MLEngineTest, method train_EmptyDataFrame.
/**
 * Verifies that training LINEAR_REGRESSION on an empty data frame (constructKMeansDataFrame(0))
 * fails with IllegalArgumentException("Input data frame should not be null or empty").
 */
@Test
public void train_EmptyDataFrame() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Input data frame should not be null or empty");
    final FunctionName linearRegression = FunctionName.LINEAR_REGRESSION;

    // MLEngineClassLoader is statically mocked only inside this try-with-resources scope;
    // initInstance is stubbed to return null for the LINEAR_REGRESSION call shape.
    try (MockedStatic<MLEngineClassLoader> classLoaderMock = Mockito.mockStatic(MLEngineClassLoader.class)) {
        classLoaderMock
                .when(() -> MLEngineClassLoader.initInstance(linearRegression, null, MLAlgoParams.class))
                .thenReturn(null);

        DataFrame emptyFrame = constructKMeansDataFrame(0);
        MLInputDataset emptyDataset = DataFrameInputDataset.builder().dataFrame(emptyFrame).build();
        MLInput trainInput = MLInput.builder().algorithm(linearRegression).inputDataset(emptyDataset).build();
        MLEngine.train(trainInput);
    }
}
Usage of org.opensearch.ml.common.dataset.MLInputDataset in the ml-commons project (opensearch-project).
Example: class MLEngineTest, method predictUnsupportedAlgorithm.
/**
 * Verifies that predicting with LINEAR_REGRESSION while MLEngineClassLoader.initInstance yields null
 * fails with IllegalArgumentException("Unsupported algorithm: LINEAR_REGRESSION").
 */
@Test
public void predictUnsupportedAlgorithm() {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("Unsupported algorithm: LINEAR_REGRESSION");
    final FunctionName linearRegression = FunctionName.LINEAR_REGRESSION;

    // The static mock is scoped to this block; initInstance returns null for the stubbed call.
    try (MockedStatic<MLEngineClassLoader> classLoaderMock = Mockito.mockStatic(MLEngineClassLoader.class)) {
        classLoaderMock
                .when(() -> MLEngineClassLoader.initInstance(linearRegression, null, MLAlgoParams.class))
                .thenReturn(null);

        DataFrame predictionFrame = constructLinearRegressionPredictionDataFrame();
        MLInputDataset predictionDataset = DataFrameInputDataset.builder().dataFrame(predictionFrame).build();
        Input predictionInput = MLInput.builder()
                .algorithm(linearRegression)
                .inputDataset(predictionDataset)
                .build();
        MLEngine.predict(predictionInput, null);
    }
}
Usage of org.opensearch.ml.common.dataset.MLInputDataset in the ml-commons project (opensearch-project).
Example: class MLEngineTest, helper method trainKMeansModel.
/**
 * Trains a KMeans model (2 centroids, 10 iterations, Euclidean distance) on the data frame
 * produced by constructKMeansDataFrame(100) and returns the model from MLEngine.train.
 */
private Model trainKMeansModel() {
    KMeansParams kMeansParams = KMeansParams.builder()
            .centroids(2)
            .iterations(10)
            .distanceType(KMeansParams.DistanceType.EUCLIDEAN)
            .build();
    MLInputDataset trainingDataset = DataFrameInputDataset.builder()
            .dataFrame(constructKMeansDataFrame(100))
            .build();
    Input trainingInput = MLInput.builder()
            .algorithm(FunctionName.KMEANS)
            .parameters(kMeansParams)
            .inputDataset(trainingDataset)
            .build();
    return MLEngine.train(trainingInput);
}
Usage of org.opensearch.ml.common.dataset.MLInputDataset in the ml-commons project (opensearch-project).
Example: class MLInputDatasetHandler, method parseSearchQueryInput.
/**
 * Create a DataFrame from the documents matched by the search query carried in the given dataset.
 *
 * <p>The search runs asynchronously via {@code client}. On success, the {@code _source} map of every
 * returned hit becomes one entry of the input to {@code DataFrameBuilder.load}, and the resulting
 * DataFrame is passed to {@code listener.onResponse}. If the response has no hits,
 * {@code listener.onFailure} receives an {@link IllegalArgumentException} ("No document found").
 * Search errors are logged and forwarded to {@code listener.onFailure} unchanged.
 *
 * @param mlInputDataset MLInputDataset; must be of type {@code SEARCH_QUERY}
 * @param listener ActionListener notified with the built DataFrame or with the failure
 * @throws IllegalArgumentException synchronously, if the dataset is not of type {@code SEARCH_QUERY}
 */
public void parseSearchQueryInput(MLInputDataset mlInputDataset, ActionListener<DataFrame> listener) {
    // Identity comparison on the enum constant; unlike calling equals() on the returned value,
    // this cannot throw a NullPointerException if no input data type is set.
    if (mlInputDataset.getInputDataType() != MLInputDataType.SEARCH_QUERY) {
        throw new IllegalArgumentException("Input dataset is not SEARCH_QUERY type.");
    }
    SearchQueryInputDataset inputDataset = (SearchQueryInputDataset) mlInputDataset;
    SearchRequest searchRequest = new SearchRequest();
    searchRequest.source(inputDataset.getSearchSourceBuilder());
    List<String> indicesList = inputDataset.getIndices();
    searchRequest.indices(indicesList.toArray(new String[0]));
    client.search(searchRequest, ActionListener.wrap(r -> {
        if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
            listener.onFailure(new IllegalArgumentException("No document found"));
            return;
        }
        SearchHit[] searchHits = r.getHits().getHits();
        // totalHits can be non-zero while no hit documents are returned (for instance when the
        // search source requests size 0); treat that exactly like an empty result.
        if (searchHits == null || searchHits.length == 0) {
            listener.onFailure(new IllegalArgumentException("No document found"));
            return;
        }
        List<Map<String, Object>> input = new ArrayList<>(searchHits.length);
        for (SearchHit hit : searchHits) {
            input.add(hit.getSourceAsMap());
        }
        listener.onResponse(DataFrameBuilder.load(input));
    }, e -> {
        // Pass the exception as the throwable argument so its stack trace is logged,
        // rather than only its toString() concatenated onto the message.
        log.error("Failed to search", e);
        listener.onFailure(e);
    }));
}
Aggregations