Search in sources:

Example 1 using LabelDictionaryEntry

Use of com.airbnb.aerosolve.core.LabelDictionaryEntry in the aerosolve project by Airbnb.

From class LowRankLinearModelTest, method makeLabelDictionary.

/**
 * Builds the small, fixed label dictionary used by the LowRankLinearModel tests.
 *
 * <p>The dictionary holds three entries in this order: "A" (animal), "C" (color) and
 * "F" (fruit). Each entry has a count of 1.
 *
 * @return a new, mutable list of the three label dictionary entries
 */
ArrayList<LabelDictionaryEntry> makeLabelDictionary() {
    // Labels in the exact order the entries must appear in the dictionary.
    final String[] labels = {"A", "C", "F"};
    ArrayList<LabelDictionaryEntry> dictionary = new ArrayList<>(labels.length);
    for (String label : labels) {
        LabelDictionaryEntry entry = new LabelDictionaryEntry();
        entry.setLabel(label);
        entry.setCount(1);
        dictionary.add(entry);
    }
    return dictionary;
}
Also used: LabelDictionaryEntry (com.airbnb.aerosolve.core.LabelDictionaryEntry), ArrayList (java.util.ArrayList)

Example 2 using LabelDictionaryEntry

Use of com.airbnb.aerosolve.core.LabelDictionaryEntry in the aerosolve project by Airbnb.

From class FullRankLinearModel, method loadInternal.

/**
 * Restores the label dictionary and the per-feature weight vectors of a full-rank linear model.
 *
 * <p>The label dictionary is copied from {@code header}, and the label index is rebuilt with
 * {@code buildLabelToIndex()}. Then exactly {@code header.getNumRecords()} lines are read from
 * {@code reader}. Each line is decoded with {@code Util.decodeModel}, and the record's weight
 * list is stored as a {@code FloatVector} under
 * {@code weightVector.get(featureFamily).get(featureName)}.
 *
 * @param header model header supplying the label dictionary and the number of records to read
 * @param reader source of the encoded model records, one record per line
 * @throws IOException if reading fails or the input ends before all records have been read
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
    long rows = header.getNumRecords();
    labelDictionary = new ArrayList<>();
    for (LabelDictionaryEntry entry : header.getLabelDictionary()) {
        labelDictionary.add(entry);
    }
    buildLabelToIndex();
    weightVector = new HashMap<>();
    for (long i = 0; i < rows; i++) {
        String line = reader.readLine();
        if (line == null) {
            // readLine() returns null at end of stream. Without this check a truncated model would
            // pass null into Util.decodeModel; report how far we got instead.
            throw new IOException(
                "Unexpected end of model input: read " + i + " of " + rows + " records");
        }
        ModelRecord record = Util.decodeModel(line);
        String family = record.getFeatureFamily();
        String name = record.getFeatureName();
        Map<String, FloatVector> inner = weightVector.get(family);
        if (inner == null) {
            inner = new HashMap<>();
            weightVector.put(family, inner);
        }
        // Fetch the weight list once instead of on every loop iteration.
        java.util.List<Double> weights = record.getWeightVector();
        FloatVector vec = new FloatVector(weights.size());
        for (int j = 0; j < weights.size(); j++) {
            vec.values[j] = weights.get(j).floatValue();
        }
        inner.put(name, vec);
    }
}
Also used: LabelDictionaryEntry (com.airbnb.aerosolve.core.LabelDictionaryEntry), ModelRecord (com.airbnb.aerosolve.core.ModelRecord), FloatVector (com.airbnb.aerosolve.core.util.FloatVector)

Example 3 using LabelDictionaryEntry

Use of com.airbnb.aerosolve.core.LabelDictionaryEntry in the aerosolve project by Airbnb.

From class LowRankLinearModel, method loadInternal.

/**
 * Restores the label dictionary, label embeddings and feature weight vectors of a low-rank
 * linear model.
 *
 * <p>Loading proceeds in three steps:
 * <ol>
 *   <li>The label dictionary is copied from {@code header}, and the label index is rebuilt with
 *       {@code buildLabelToIndex()}.</li>
 *   <li>{@code embeddingDimension} is taken from the size of the first label embedding in the
 *       header. The first {@code embeddingDimension} values of every label's embedding are then
 *       stored in {@code labelWeightVector}.</li>
 *   <li>Exactly {@code header.getNumRecords()} lines are read from {@code reader}. Each line is
 *       decoded with {@code Util.decodeModel}, and its weight list is stored under
 *       {@code featureWeightVector.get(featureFamily).get(featureName)}.</li>
 * </ol>
 *
 * @param header model header supplying the label dictionary, label embeddings and record count
 * @param reader source of the encoded feature records, one record per line
 * @throws IOException if reading fails, the header has no label embeddings, a label embedding
 *     is shorter than the inferred dimension, or the input ends before all records are read
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
    long rows = header.getNumRecords();
    labelDictionary = new ArrayList<>();
    for (LabelDictionaryEntry entry : header.getLabelDictionary()) {
        labelDictionary.add(entry);
    }
    buildLabelToIndex();
    labelWeightVector = new HashMap<>();
    Map<String, java.util.List<Double>> labelEmbedding = header.getLabelEmbedding();
    if (labelEmbedding.isEmpty()) {
        // iterator().next() below would otherwise throw a context-free NoSuchElementException.
        throw new IOException(
            "Model header contains no label embeddings; cannot infer embedding dimension");
    }
    embeddingDimension = labelEmbedding.entrySet().iterator().next().getValue().size();
    for (Map.Entry<String, java.util.List<Double>> labelRepresentation : labelEmbedding.entrySet()) {
        java.util.List<Double> values = labelRepresentation.getValue();
        String labelKey = labelRepresentation.getKey();
        if (values.size() < embeddingDimension) {
            // A shorter vector would fail with an IndexOutOfBoundsException below; name the label.
            throw new IOException("Label embedding for '" + labelKey + "' has " + values.size()
                + " values, expected " + embeddingDimension);
        }
        FloatVector labelWeight = new FloatVector(embeddingDimension);
        for (int i = 0; i < embeddingDimension; i++) {
            labelWeight.set(i, values.get(i).floatValue());
        }
        labelWeightVector.put(labelKey, labelWeight);
    }
    featureWeightVector = new HashMap<>();
    for (long i = 0; i < rows; i++) {
        String line = reader.readLine();
        if (line == null) {
            // readLine() returns null at end of stream. Without this check a truncated model would
            // pass null into Util.decodeModel; report how far we got instead.
            throw new IOException(
                "Unexpected end of model input: read " + i + " of " + rows + " records");
        }
        ModelRecord record = Util.decodeModel(line);
        String family = record.getFeatureFamily();
        String name = record.getFeatureName();
        Map<String, FloatVector> inner = featureWeightVector.get(family);
        if (inner == null) {
            inner = new HashMap<>();
            featureWeightVector.put(family, inner);
        }
        // Fetch the weight list once instead of on every loop iteration.
        java.util.List<Double> weights = record.getWeightVector();
        FloatVector vec = new FloatVector(weights.size());
        for (int j = 0; j < weights.size(); j++) {
            vec.values[j] = weights.get(j).floatValue();
        }
        inner.put(name, vec);
    }
}
Also used: LabelDictionaryEntry (com.airbnb.aerosolve.core.LabelDictionaryEntry), FloatVector (com.airbnb.aerosolve.core.util.FloatVector), java.util (java.util), ModelRecord (com.airbnb.aerosolve.core.ModelRecord)

Aggregations (number of examples above that use each type)

LabelDictionaryEntry (com.airbnb.aerosolve.core.LabelDictionaryEntry): 3, ModelRecord (com.airbnb.aerosolve.core.ModelRecord): 2, FloatVector (com.airbnb.aerosolve.core.util.FloatVector): 2, java.util (java.util): 1, ArrayList (java.util.ArrayList): 1