Use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.
The class ModelConverter, method encodeSchema.
/**
 * Builds the {@code Schema} (target label plus input features) for the wrapped Spark model.
 *
 * <p>A label is derived only when the model implements {@code HasLabelCol}; otherwise it stays {@code null}.
 * For {@code CLASSIFICATION}, a categorical label column yields a {@code CategoricalLabel} over its data field,
 * while a continuous label column is converted into a categorical one with categories {@code "0"} through
 * {@code String.valueOf(numClasses - 1)} ({@code numClasses} is taken from a {@code ClassificationModel},
 * defaulting to 2) and re-registered as the column's only feature. For {@code REGRESSION}, the label column is
 * converted to a continuous field whose data type is forced to {@code DOUBLE}.
 *
 * @param encoder encoder that resolves the label and features columns
 * @return schema combining the (possibly {@code null}) label and the features of the model's features column
 * @throws IllegalArgumentException if the label column cannot be encoded for the model's mining function,
 *         if the mining function is unsupported, if a classification model has no categorical label,
 *         or if the number of target categories or features disagrees with the model
 */
public Schema encodeSchema(SparkMLEncoder encoder) {
    T model = getTransformer();

    Label label = null;

    if (model instanceof HasLabelCol) {
        HasLabelCol hasLabelCol = (HasLabelCol) model;

        String labelCol = hasLabelCol.getLabelCol();

        Feature feature = encoder.getOnlyFeature(labelCol);

        MiningFunction miningFunction = getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION:
                {
                    if (feature instanceof CategoricalFeature) {
                        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

                        DataField dataField = encoder.getDataField(categoricalFeature.getName());

                        label = new CategoricalLabel(dataField);
                    } else if (feature instanceof ContinuousFeature) {
                        ContinuousFeature continuousFeature = (ContinuousFeature) feature;

                        // Without model metadata, assume a binary target
                        int numClasses = 2;

                        if (model instanceof ClassificationModel) {
                            ClassificationModel<?, ?> classificationModel = (ClassificationModel<?, ?>) model;

                            numClasses = classificationModel.numClasses();
                        }

                        // Spark encodes class labels as 0-based indices, so categories are "0" .. "numClasses - 1"
                        List<String> categories = new ArrayList<>();
                        for (int i = 0; i < numClasses; i++) {
                            categories.add(String.valueOf(i));
                        }

                        Field<?> field = encoder.toCategorical(continuousFeature.getName(), categories);

                        // Replace the continuous feature so later lookups of the label column see the categorical view
                        encoder.putOnlyFeature(labelCol, new CategoricalFeature(encoder, field, categories));

                        label = new CategoricalLabel(field.getName(), field.getDataType(), categories);
                    } else {
                        throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + feature);
                    }
                }
                break;
            case REGRESSION:
                {
                    Field<?> field = encoder.toContinuous(feature.getName());

                    field.setDataType(DataType.DOUBLE);

                    label = new ContinuousLabel(field.getName(), field.getDataType());
                }
                break;
            default:
                throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
        }
    }

    if (model instanceof ClassificationModel) {
        ClassificationModel<?, ?> classificationModel = (ClassificationModel<?, ?>) model;

        // Guard the cast: a missing label (no label column) or a continuous label would otherwise
        // surface as an uninformative NullPointerException or ClassCastException
        if (!(label instanceof CategoricalLabel)) {
            throw new IllegalArgumentException("Expected a categorical label for a classification model, got " + label);
        }

        CategoricalLabel categoricalLabel = (CategoricalLabel) label;

        int numClasses = classificationModel.numClasses();
        if (numClasses != categoricalLabel.size()) {
            throw new IllegalArgumentException("Expected " + numClasses + " target categories, got " + categoricalLabel.size() + " target categories");
        }
    }

    String featuresCol = model.getFeaturesCol();

    List<Feature> features = encoder.getFeatures(featuresCol);

    if (model instanceof PredictionModel) {
        PredictionModel<?, ?> predictionModel = (PredictionModel<?, ?>) model;

        // -1 means the model does not know its feature count, so the check is skipped
        int numFeatures = predictionModel.numFeatures();
        if (numFeatures != -1 && features.size() != numFeatures) {
            throw new IllegalArgumentException("Expected " + numFeatures + " features, got " + features.size() + " features");
        }
    }

    Schema result = new Schema(label, features);

    return result;
}
Use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.
The class SparkMLEncoder, method getFeatures.
/**
 * Returns the features registered for {@code column}.
 *
 * <p>If nothing is registered in {@code columnFeatures} for the column, a single feature is synthesized
 * from the column's data field, which is created first when it does not exist yet. The kind of the
 * synthesized feature follows the field's data type: {@code STRING} gives a {@code WildcardFeature},
 * {@code INTEGER}/{@code DOUBLE} give a {@code ContinuousFeature}, and {@code BOOLEAN} gives a
 * {@code BooleanFeature}. The synthesized feature is returned as a singleton list and is not stored
 * in {@code columnFeatures}.
 *
 * @param column name of the Spark column
 * @return the registered features, or a singleton list holding a freshly synthesized feature
 * @throws IllegalArgumentException if the synthesized field has an unsupported data type
 */
public List<Feature> getFeatures(String column) {
    List<Feature> registered = this.columnFeatures.get(column);
    if (registered != null) {
        return registered;
    }

    FieldName fieldName = FieldName.create(column);

    DataField dataField = getDataField(fieldName);
    if (dataField == null) {
        dataField = createDataField(fieldName);
    }

    DataType dataType = dataField.getDataType();
    switch (dataType) {
        case STRING:
            return Collections.singletonList(new WildcardFeature(this, dataField));
        case INTEGER:
        case DOUBLE:
            return Collections.singletonList(new ContinuousFeature(this, dataField));
        case BOOLEAN:
            return Collections.singletonList(new BooleanFeature(this, dataField));
        default:
            throw new IllegalArgumentException("Data type " + dataType + " is not supported");
    }
}
Use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.
The class SparkMLEncoder, method getFeatures (indexed overload).
/**
 * Selects features of {@code column} by position.
 *
 * <p>The result lists the features at the given indices, in the order the indices appear
 * (duplicates are allowed and repeat the feature).
 *
 * @param column name of the Spark column whose features are looked up via {@code getFeatures(String)}
 * @param indices zero-based positions into that column's feature list
 * @return a new mutable list with the selected features
 * @throws IndexOutOfBoundsException if an index is outside the column's feature list
 */
public List<Feature> getFeatures(String column, int[] indices) {
    List<Feature> columnFeatureList = getFeatures(column);

    List<Feature> selected = new ArrayList<>();
    for (int position : indices) {
        selected.add(columnFeatureList.get(position));
    }

    return selected;
}
Use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.
The class TermFeature, method createApply.
/**
 * Builds an {@code Apply} that calls this term's defined function.
 *
 * <p>The call passes two arguments: a reference to the wrapped feature, and the term value
 * as a {@code STRING} constant.
 *
 * @return the function-call expression
 */
public Apply createApply() {
    DefineFunction termFunction = getDefineFunction();
    Feature argumentFeature = getFeature();
    Constant termConstant = PMMLUtil.createConstant(getValue(), DataType.STRING);

    return PMMLUtil.createApply(termFunction.getName(), argumentFeature.ref(), termConstant);
}
Use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.
The class StandardScalerModelConverter, method encodeFeatures.
/**
 * Encodes a Spark {@code StandardScalerModel} as one continuous {@code DOUBLE} derived field per input feature.
 *
 * <p>Each output is {@code (x - mean[i]) * (1 / std[i])}. The centering step applies only when
 * {@code withMean} is set, and the scaling step only when {@code withStd} is set. Steps that would be
 * no-ops (a zero mean, a unit std) are left out of the generated expression.
 *
 * <p>A zero standard deviation maps to a scaling factor of {@code 0}. This mirrors Spark, which outputs
 * {@code 0.0} for zero-variance components. The previous {@code 1 / 0} produced an Infinity constant,
 * which makes the PMML yield Infinity or NaN.
 *
 * @param encoder encoder that resolves the input column's features and creates the derived fields
 * @return one continuous feature per input feature, in input order
 * @throws IllegalArgumentException if the mean (when {@code withMean}) or std (when {@code withStd})
 *         vector length does not match the number of input features
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    StandardScalerModel transformer = getTransformer();

    List<Feature> features = encoder.getFeatures(transformer.getInputCol());

    boolean withMean = transformer.getWithMean();
    boolean withStd = transformer.getWithStd();

    Vector mean = transformer.mean();
    if (withMean && mean.size() != features.size()) {
        throw new IllegalArgumentException("Expected " + features.size() + " mean values, got " + mean.size() + " mean values");
    }

    Vector std = transformer.std();
    if (withStd && std.size() != features.size()) {
        throw new IllegalArgumentException("Expected " + features.size() + " standard deviation values, got " + std.size() + " standard deviation values");
    }

    List<Feature> result = new ArrayList<>(features.size());

    for (int i = 0; i < features.size(); i++) {
        ContinuousFeature continuousFeature = features.get(i).toContinuousFeature();

        Expression expression = continuousFeature.ref();

        if (withMean) {
            double meanValue = mean.apply(i);

            if (!ValueUtil.isZero(meanValue)) {
                expression = PMMLUtil.createApply("-", expression, PMMLUtil.createConstant(meanValue));
            }
        }

        if (withStd) {
            double stdValue = std.apply(i);

            if (!ValueUtil.isOne(stdValue)) {
                // Spark outputs 0.0 for zero-variance components; dividing by zero would emit Infinity
                double factor = (stdValue != 0d) ? (1d / stdValue) : 0d;

                expression = PMMLUtil.createApply("*", expression, PMMLUtil.createConstant(factor));
            }
        }

        DerivedField derivedField = encoder.createDerivedField(formatName(transformer, i), OpType.CONTINUOUS, DataType.DOUBLE, expression);

        result.add(new ContinuousFeature(encoder, derivedField));
    }

    return result;
}
Aggregations