Search in sources:

Example 1 with IntegrationTestBatch

Use of org.jpmml.evaluator.IntegrationTestBatch in the jpmml-r project by jpmml.

Class ConverterTest, method createBatch.

/**
 * Builds an integration-test batch backed by a serialized R object.
 *
 * <p>The returned batch reports this {@code ConverterTest} instance as its owning integration
 * test. Its PMML is obtained by reading the resource {@code /rds/<name><dataset>.rds}, parsing
 * it with {@link RExpParser}, and converting the parsed object with {@code convert(..., clazz)}.
 * The result is passed through {@code ensureValidity} before it is returned.
 *
 * @param name the batch name; it is also the first part of the resource file name
 * @param dataset the dataset name; it is appended to {@code name} to form the resource file name
 * @param predicate the field predicate forwarded unchanged to {@link IntegrationTestBatch}
 * @param clazz the converter class handed to {@code convert} for the parsed R object
 * @return a new batch whose {@code getPMML()} performs the parse-and-convert step on each call
 */
protected ArchiveBatch createBatch(String name, String dataset, Predicate<FieldName> predicate, final Class<? extends Converter<? extends RExp>> clazz) {
    return new IntegrationTestBatch(name, dataset, predicate) {

        @Override
        public IntegrationTest getIntegrationTest() {
            return ConverterTest.this;
        }

        @Override
        public PMML getPMML() throws Exception {
            String resource = "/rds/" + getName() + getDataset() + ".rds";

            // The stream is only needed while parsing; try-with-resources closes it afterwards.
            try (InputStream in = open(resource)) {
                RExp parsed = new RExpParser(in).parse();

                PMML converted = convert(parsed, clazz);
                ensureValidity(converted);

                return converted;
            }
        }
    };
}
Also used: ArchiveBatch (org.jpmml.evaluator.ArchiveBatch), IntegrationTestBatch (org.jpmml.evaluator.IntegrationTestBatch), InputStream (java.io.InputStream), PMML (org.dmg.pmml.PMML)

Example 2 with IntegrationTestBatch

Use of org.jpmml.evaluator.IntegrationTestBatch in the jpmml-sparkml project by jpmml.

Class ConverterTest, method createBatch.

/**
 * Builds an integration-test batch backed by a serialized Spark ML pipeline.
 *
 * <p>The caller's {@code predicate} is AND-ed with
 * {@code excludeFields("prediction", "pmml(prediction)")}. If the caller passes {@code null},
 * only that exclusion predicate is used. Presumably this keeps those two columns out of the
 * result comparison (TODO confirm against {@code excludeFields}).
 *
 * <p>The batch's {@code getPMML()} works in four steps:
 * <ol>
 *   <li>It reads the Spark schema from {@code /schema/<dataset>.json}.</li>
 *   <li>It copies {@code /pipeline/<name><dataset>.zip} to a temporary archive.</li>
 *   <li>It extracts that archive into a temporary directory and loads a {@link PipelineModel}
 *       from it with {@code ConverterTest.sparkSession}.</li>
 *   <li>It converts the model with {@code ConverterUtil.toPMML} and passes the result through
 *       {@code ensureValidity}.</li>
 * </ol>
 *
 * @param name the batch name; it is also part of the pipeline resource file name
 * @param dataset the dataset name; it selects the schema resource and is part of the pipeline
 *     resource file name
 * @param predicate an optional field predicate; it may be {@code null}
 * @return a new batch whose {@code getPMML()} performs the load-and-convert step on each call
 */
@Override
protected ArchiveBatch createBatch(String name, String dataset, Predicate<FieldName> predicate) {
    Predicate<FieldName> excludePredictionFields = excludeFields(FieldName.create("prediction"), FieldName.create("pmml(prediction)"));
    if (predicate == null) {
        predicate = excludePredictionFields;
    } else {
        predicate = Predicates.and(predicate, excludePredictionFields);
    }
    ArchiveBatch result = new IntegrationTestBatch(name, dataset, predicate) {

        @Override
        public IntegrationTest getIntegrationTest() {
            return ConverterTest.this;
        }

        @Override
        public PMML getPMML() throws Exception {
            StructType schema;
            try (InputStream is = open("/schema/" + getDataset() + ".json")) {
                String json = CharStreams.toString(new InputStreamReader(is, java.nio.charset.StandardCharsets.UTF_8));
                schema = (StructType) DataType.fromJson(json);
            }

            PipelineModel pipelineModel;
            File tmpZipFile = File.createTempFile(getName() + getDataset(), ".zip");
            try {
                // Copy the bundled archive to disk and close the resource stream before extraction.
                try (InputStream is = open("/pipeline/" + getName() + getDataset() + ".zip");
                     OutputStream os = new FileOutputStream(tmpZipFile)) {
                    ByteStreams.copy(is, os);
                }

                // createTempFile creates a regular file; delete it to reserve a unique path for
                // the extraction directory.
                File tmpDir = File.createTempFile(getName() + getDataset(), "");
                if (!tmpDir.delete()) {
                    throw new IOException("Failed to delete temporary file " + tmpDir.getAbsolutePath());
                }
                ZipUtil.uncompress(tmpZipFile, tmpDir);

                // NOTE(review): tmpDir is still never removed. It is left in place because it is
                // not clear from here whether the loaded model reads from it lazily.
                MLReader<PipelineModel> mlReader = new PipelineModel.PipelineModelReader();
                mlReader.session(ConverterTest.sparkSession);
                pipelineModel = mlReader.load(tmpDir.getAbsolutePath());
            } finally {
                // The archive copy is no longer needed once it has been extracted (or on failure).
                if (!tmpZipFile.delete()) {
                    tmpZipFile.deleteOnExit();
                }
            }

            PMML pmml = ConverterUtil.toPMML(schema, pipelineModel);
            ensureValidity(pmml);
            return pmml;
        }
    };
    return result;
}
Also used: IntegrationTestBatch (org.jpmml.evaluator.IntegrationTestBatch), StructType (org.apache.spark.sql.types.StructType), InputStreamReader (java.io.InputStreamReader), InputStream (java.io.InputStream), OutputStream (java.io.OutputStream), FileOutputStream (java.io.FileOutputStream), IOException (java.io.IOException), PipelineModel (org.apache.spark.ml.PipelineModel), ArchiveBatch (org.jpmml.evaluator.ArchiveBatch), PMML (org.dmg.pmml.PMML), FieldName (org.dmg.pmml.FieldName), File (java.io.File)

Aggregations

InputStream (java.io.InputStream): 2; PMML (org.dmg.pmml.PMML): 2; ArchiveBatch (org.jpmml.evaluator.ArchiveBatch): 2; IntegrationTestBatch (org.jpmml.evaluator.IntegrationTestBatch): 2; File (java.io.File): 1; FileOutputStream (java.io.FileOutputStream): 1; IOException (java.io.IOException): 1; InputStreamReader (java.io.InputStreamReader): 1; OutputStream (java.io.OutputStream): 1; PipelineModel (org.apache.spark.ml.PipelineModel): 1; StructType (org.apache.spark.sql.types.StructType): 1; FieldName (org.dmg.pmml.FieldName): 1