Use of org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock in the drools project by kiegroup.
Class KiePMMLRegressionModelFactoryTest, method getRegressionTablesMap:
/**
 * Verifies that getRegressionTablesMap produces one entry per declared regression table plus one
 * extra entry for the classification table, and that every declared target category is present
 * among the generated table categories.
 */
@Test
public void getRegressionTablesMap() {
    final CompilationDTO<RegressionModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel,
                                                                   new HasClassLoaderMock());
    final Map<String, KiePMMLTableSourceCategory> tablesByName =
            KiePMMLRegressionModelFactory.getRegressionTablesMap(RegressionCompilationDTO.fromCompilationDTO(commonDTO));
    // Every declared regression table, plus the additional classification table
    assertEquals(regressionTables.size() + 1, tablesByName.size());
    final Collection<KiePMMLTableSourceCategory> generated = tablesByName.values();
    regressionTables.forEach(table -> {
        final Object wantedCategory = table.getTargetCategory();
        assertTrue(generated.stream()
                           .map(KiePMMLTableSourceCategory::getCategory)
                           .anyMatch(category -> category.equals(wantedCategory)));
    });
}
Use of org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock in the drools project by kiegroup.
Class KiePMMLRegressionModelFactoryTest, method getKiePMMLRegressionModelClasses:
/**
 * Verifies that getKiePMMLRegressionModelClasses builds a model whose name, mining function and
 * target field mirror the source RegressionModel, and whose regression table is a
 * KiePMMLClassificationTable.
 */
@Test
public void getKiePMMLRegressionModelClasses() throws IOException, IllegalAccessException, InstantiationException {
    final CompilationDTO<RegressionModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel,
                                                                   new HasClassLoaderMock());
    final KiePMMLRegressionModel model =
            KiePMMLRegressionModelFactory.getKiePMMLRegressionModelClasses(RegressionCompilationDTO.fromCompilationDTO(commonDTO));
    assertNotNull(model);
    assertEquals(regressionModel.getModelName(), model.getName());
    assertEquals(MINING_FUNCTION.byName(regressionModel.getMiningFunction().value()), model.getMiningFunction());
    // The first mining field is expected to be the target
    final String expectedTarget = miningFields.get(0).getName().getValue();
    assertEquals(expectedTarget, model.getTargetField());
    final AbstractKiePMMLTable table = model.getRegressionTable();
    assertNotNull(table);
    assertTrue(table instanceof KiePMMLClassificationTable);
    final KiePMMLClassificationTable classificationTable = (KiePMMLClassificationTable) table;
    evaluateCategoricalRegressionTable(classificationTable);
}
Use of org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock in the drools project by kiegroup.
Class KiePMMLRegressionModelFactoryTest, method getKiePMMLRegressionModelSourcesMap:
/**
 * Verifies that getKiePMMLRegressionModelSourcesMap generates one source per declared regression
 * table, plus one for the classification table and one for the whole model.
 */
@Test
public void getKiePMMLRegressionModelSourcesMap() throws IOException {
    final CommonCompilationDTO<RegressionModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel,
                                                                   new HasClassLoaderMock());
    final Map<String, String> sources =
            KiePMMLRegressionModelFactory.getKiePMMLRegressionModelSourcesMap(RegressionCompilationDTO.fromCompilationDTO(commonDTO));
    assertNotNull(sources);
    final int tableSources = regressionTables.size();
    // One extra source for the classification table, one for the whole model
    final int extraSources = 2;
    assertEquals(tableSources + extraSources, sources.size());
}
Use of org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock in the drools project by kiegroup.
Class RegressionModelImplementationProviderTest, method getKiePMMLModel:
/**
 * Loads SOURCE_1, which must contain exactly one RegressionModel. Verifies that the provider turns
 * it into a non-null KiePMMLRegressionModel that is Serializable.
 */
@Test
public void getKiePMMLModel() throws Exception {
    final PMML loaded = TestUtils.loadFromFile(SOURCE_1);
    assertNotNull(loaded);
    assertEquals(1, loaded.getModels().size());
    final Object firstModel = loaded.getModels().get(0);
    assertTrue(firstModel instanceof RegressionModel);
    final RegressionModel model = (RegressionModel) firstModel;
    final CommonCompilationDTO<RegressionModel> commonDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, loaded, model,
                                                                   new HasClassLoaderMock());
    final KiePMMLRegressionModel kiePMMLModel = PROVIDER.getKiePMMLModel(commonDTO);
    assertNotNull(kiePMMLModel);
    assertTrue(kiePMMLModel instanceof Serializable);
}
Use of org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock in the drools project by kiegroup.
Class RegressionModelImplementationProviderTest, method getKiePMMLModelWithSources:
/**
 * Loads SOURCE_1, which must contain exactly one RegressionModel, and asks the provider for the
 * model together with its generated sources. Compiles those sources in memory and checks that every
 * resulting class implements Serializable.
 */
@Test
public void getKiePMMLModelWithSources() throws Exception {
    final PMML pmml = TestUtils.loadFromFile(SOURCE_1);
    assertNotNull(pmml);
    assertEquals(1, pmml.getModels().size());
    assertTrue(pmml.getModels().get(0) instanceof RegressionModel);
    final RegressionModel regressionModel = (RegressionModel) pmml.getModels().get(0);
    final CommonCompilationDTO<RegressionModel> compilationDTO =
            CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel,
                                                                   new HasClassLoaderMock());
    final KiePMMLModelWithSources retrieved = PROVIDER.getKiePMMLModelWithSources(compilationDTO);
    assertNotNull(retrieved);
    final Map<String, String> sourcesMap = retrieved.getSourcesMap();
    assertNotNull(sourcesMap);
    assertFalse(sourcesMap.isEmpty());
    final ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
    final Map<String, Class<?>> compiled = KieMemoryCompiler.compile(sourcesMap, classLoader);
    // Guard against the loop below passing vacuously on an empty result.
    assertFalse(compiled.isEmpty());
    for (Class<?> clazz : compiled.values()) {
        // "clazz instanceof Serializable" is always true, because java.lang.Class itself implements
        // Serializable. What matters is whether the generated type implements it.
        assertTrue(Serializable.class.isAssignableFrom(clazz));
    }
}
Aggregations