Usage of `org.dmg.pmml.VectorInstanceDocument.VectorInstance` in the knime-core project by KNIME.
Class `PMMLSVMTranslator`, method `addVectorDictionary`.
/**
 * Writes the {@code VectorDictionary} element of the PMML SVM model.
 * <p>
 * The support vectors of every SVM in {@code m_svms} are merged into a single
 * insertion-ordered set, so a vector shared by several machines is written only
 * once (as far as {@code DoubleVector} equality identifies it). Each vector is
 * stored as a {@code REAL-SparseArray} whose indices run densely from 1 to the
 * number of values, i.e. every value is written explicitly.
 *
 * @param svmModel the PMML SVM model that receives the new vector dictionary
 * @param learningCol the names of the learning columns; they are written, in
 *            this order, as field references using their derived field names
 */
private void addVectorDictionary(final SupportVectorMachineModel svmModel, final List<String> learningCol) {
    final VectorDictionary dictionary = svmModel.addNewVectorDictionary();

    // Gather each support vector once, keeping the order in which it was first seen.
    final Set<DoubleVector> distinctVectors = new LinkedHashSet<DoubleVector>();
    for (Svm machine : m_svms) {
        for (DoubleVector candidate : machine.getSupportVectors()) {
            distinctVectors.add(candidate);
        }
    }
    dictionary.setNumberOfVectors(BigInteger.valueOf(distinctVectors.size()));

    // Declare the fields (columns) that every vector instance is defined over.
    final VectorFields fields = dictionary.addNewVectorFields();
    fields.setNumberOfFields(BigInteger.valueOf(learningCol.size()));
    for (String columnName : learningCol) {
        fields.addNewFieldRef().setField(m_nameMapper.getDerivedFieldName(columnName));
    }

    for (DoubleVector vector : distinctVectors) {
        final VectorInstance instance = dictionary.addNewVectorInstance();
        // The id encodes the class value and the row key, joined by the separator.
        instance.setId(vector.getClassValue() + CLASS_KEY_SEPARATOR + vector.getKey().getString());

        final REALSparseArray array = instance.addNewREALSparseArray1();
        final int valueCount = vector.getNumberValues();
        array.setN(BigInteger.valueOf(valueCount));

        // Indices are written 1-based (1..n); index k carries the vector's value at offset k - 1.
        final List<String> indices = new ArrayList<String>();
        final List<String> entries = new ArrayList<String>();
        for (int offset = 0; offset < valueCount; offset++) {
            indices.add(String.valueOf(offset + 1));
            entries.add(String.valueOf(vector.getValue(offset)));
        }
        array.setIndices(indices);
        array.setREALEntries(entries);
    }
}
Usage of `org.dmg.pmml.VectorInstanceDocument.VectorInstance` in the knime-core project by KNIME.
Class `PMMLSVMTranslator`, method `initSVMs`.
/**
 * Rebuilds the internal {@code Svm} list ({@code m_svms}) from a PMML
 * {@code SupportVectorMachineModel}.
 * <p>
 * First, the model's vector dictionary is indexed by vector id. Then each PMML
 * {@code SupportVectorMachine} is turned into one {@code Svm}, whose support
 * vectors are looked up in that index. If the model contains exactly one
 * machine, a second {@code Svm} is appended. It has the alternate target
 * category and the negated threshold, because the KNIME representation uses
 * two SVMs for the binary case.
 *
 * @param svmModel the PMML SVM model to read; only its vector dictionary and
 *            its support vector machines are accessed
 */
private void initSVMs(final SupportVectorMachineModel svmModel) {
    // Index the dictionary's vectors by id, keeping document order.
    final Map<String, ArrayList<Double>> vectors = new LinkedHashMap<String, ArrayList<Double>>();
    for (VectorInstance vectorInstance : svmModel.getVectorDictionary().getVectorInstanceArray()) {
        REALSparseArray sparseArray = vectorInstance.getREALSparseArray1();
        ArrayList<Double> values = new ArrayList<Double>();
        // Indices are not consulted: the entries are taken in order as a dense
        // vector (the layout addVectorDictionary writes: indices 1..n).
        for (Object realEntry : sparseArray.getREALEntries()) {
            values.add((Double) realEntry);
        }
        vectors.put(vectorInstance.getId(), values);
    }

    // Fetch the machines once; the generated getter builds a new array on every call.
    final SupportVectorMachine[] machines = svmModel.getSupportVectorMachineArray();
    // A single PMML machine means binary classification (see the mirrored SVM below).
    final boolean singleMachine = machines.length == 1;

    for (SupportVectorMachine supportVectorMachine : machines) {
        // Collect the support vectors. The array is sized by the elements actually
        // present, not by the numberOfSupportVectors attribute. That attribute may
        // be absent (null) or disagree with the content, which would cause an NPE,
        // an ArrayIndexOutOfBoundsException, or trailing null entries.
        SupportVector[] supportVectorArray = supportVectorMachine.getSupportVectors().getSupportVectorArray();
        DoubleVector[] supportVectors = new DoubleVector[supportVectorArray.length];
        for (int i = 0; i < supportVectorArray.length; i++) {
            String vectorId = supportVectorArray[i].getVectorId();
            String classValue = getClassValue(vectorId);
            supportVectors[i] = new DoubleVector(new RowKey(vectorId), vectors.get(vectorId), classValue);
        }

        Coefficients coef = supportVectorMachine.getCoefficients();
        double threshold = coef.getAbsoluteValue();

        // Collect the coefficients. The alpha in KNIME is always positive. When
        // the distance is calculated, alpha is multiplied by a target factor of
        // 1 or -1, depending on whether the example is positive or not
        // (see SVM#getTargetAlphas() for details).
        Coefficient[] coefficientArray = coef.getCoefficientArray();
        double[] alpha = new double[coefficientArray.length];
        for (int i = 0; i < coefficientArray.length; i++) {
            alpha[i] = Math.abs(coefficientArray[i].getValue());
        }

        m_svms.add(new Svm(supportVectors, alpha, supportVectorMachine.getTargetCategory(), threshold, m_kernel));

        // The KNIME internal representation requires two SVMs for the binary
        // classification case. Therefore add a second SVM with the same
        // configuration as the first one, except for the negated threshold and
        // the alternate target category. It gets its own array copies, so neither
        // SVM can affect the other through shared arrays.
        if (singleMachine) {
            m_svms.add(new Svm(supportVectors.clone(), alpha.clone(),
                supportVectorMachine.getAlternateTargetCategory(), threshold * -1, m_kernel));
        }
    }
}
Aggregations