Search in sources:

Example 26 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

The class KernelModel, method loadInternal.

/**
 * Restores the dictionary and support vectors of a kernel model.
 *
 * <p>The string dictionary is rebuilt from {@code header.getDictionary()}. Then exactly
 * {@code header.getNumRecords()} lines are read from {@code reader}. Each line is decoded
 * with {@code Util.decodeModel} and wrapped in a {@code SupportVector}.
 *
 * @param header the already-decoded model header
 * @param reader source of the support-vector records, one encoded record per line
 * @throws IOException if reading fails or {@code reader} ends before all records are read
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
    long rows = header.getNumRecords();
    dictionary = new StringDictionary(header.getDictionary());
    supportVectors = new ArrayList<>();
    for (long i = 0; i < rows; i++) {
        String line = reader.readLine();
        if (line == null) {
            // readLine() returns null at end of stream: fail with a clear message rather
            // than handing null to Util.decodeModel.
            throw new IOException(
                "Unexpected end of model: expected " + rows + " records but read only " + i);
        }
        ModelRecord record = Util.decodeModel(line);
        supportVectors.add(new SupportVector(record));
    }
}
Also used : ModelRecord(com.airbnb.aerosolve.core.ModelRecord) SupportVector(com.airbnb.aerosolve.core.util.SupportVector) StringDictionary(com.airbnb.aerosolve.core.util.StringDictionary)

Example 27 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

The class LinearModel, method loadInternal.

/**
 * Loads a linear model from a buffered stream.
 *
 * <p>Slope and offset are copied from {@code header} only when set, because very old models
 * did not set them. Then exactly {@code header.getNumRecords()} lines are read from
 * {@code reader}. Each line is decoded with {@code Util.decodeModel}, and its weight is
 * stored under its feature family and feature name.
 *
 * @param header the already-decoded model header
 * @param reader source of the weight records, one encoded record per line
 * @throws IOException if reading fails or {@code reader} ends before all records are read
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
    long rows = header.getNumRecords();
    // Very old models did not set slope and offset so check first.
    if (header.isSetSlope()) {
        slope = header.getSlope();
    }
    if (header.isSetOffset()) {
        offset = header.getOffset();
    }
    weights = new HashMap<>();
    for (long i = 0; i < rows; i++) {
        String line = reader.readLine();
        if (line == null) {
            // readLine() returns null at end of stream: fail with a clear message rather
            // than handing null to Util.decodeModel.
            throw new IOException(
                "Unexpected end of model: expected " + rows + " records but read only " + i);
        }
        ModelRecord record = Util.decodeModel(line);
        String family = record.getFeatureFamily();
        String name = record.getFeatureName();
        Map<String, Float> inner = weights.get(family);
        if (inner == null) {
            inner = new HashMap<>();
            weights.put(family, inner);
        }
        float weight = (float) record.getFeatureWeight();
        inner.put(name, weight);
    }
}
Also used : ModelRecord(com.airbnb.aerosolve.core.ModelRecord)

Example 28 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

The class LowRankLinearModel, method loadInternal.

/**
 * Restores a low-rank linear model from its header and the feature records that follow it.
 *
 * <p>The label dictionary and the per-label embedding vectors come from {@code header}. The
 * embedding dimension is the length of the first label embedding returned by the header's map.
 * The per-feature weight vectors come from the next {@code header.getNumRecords()} lines of
 * {@code reader}, one encoded {@code ModelRecord} per line. Vectors are keyed by feature
 * family, then feature name.
 *
 * @param header the already-decoded model header
 * @param reader source of the feature records, one encoded record per line
 * @throws IOException if reading fails, the header carries no label embedding, a label
 *     embedding has fewer values than the embedding dimension, or {@code reader} ends before
 *     all records are read
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
    long rows = header.getNumRecords();
    labelDictionary = new ArrayList<>(header.getLabelDictionary());
    buildLabelToIndex();

    Map<String, java.util.List<Double>> labelEmbedding = header.getLabelEmbedding();
    if (labelEmbedding == null || labelEmbedding.isEmpty()) {
        // Without at least one embedding the dimension cannot be determined.
        throw new IOException("Model header has no label embedding");
    }
    embeddingDimension = labelEmbedding.values().iterator().next().size();
    labelWeightVector = new HashMap<>();
    for (Map.Entry<String, java.util.List<Double>> labelRepresentation : labelEmbedding.entrySet()) {
        java.util.List<Double> values = labelRepresentation.getValue();
        String labelKey = labelRepresentation.getKey();
        if (values.size() < embeddingDimension) {
            throw new IOException("Label embedding for '" + labelKey + "' has " + values.size()
                + " values but expected " + embeddingDimension);
        }
        // Values beyond embeddingDimension are ignored, as before.
        FloatVector labelWeight = new FloatVector(embeddingDimension);
        for (int i = 0; i < embeddingDimension; i++) {
            labelWeight.set(i, values.get(i).floatValue());
        }
        labelWeightVector.put(labelKey, labelWeight);
    }

    featureWeightVector = new HashMap<>();
    for (long i = 0; i < rows; i++) {
        String line = reader.readLine();
        if (line == null) {
            // readLine() returns null at end of stream: fail with a clear message rather
            // than handing null to Util.decodeModel.
            throw new IOException(
                "Unexpected end of model: expected " + rows + " records but read only " + i);
        }
        ModelRecord record = Util.decodeModel(line);
        String family = record.getFeatureFamily();
        String name = record.getFeatureName();
        Map<String, FloatVector> inner = featureWeightVector.get(family);
        if (inner == null) {
            inner = new HashMap<>();
            featureWeightVector.put(family, inner);
        }
        java.util.List<Double> weightValues = record.getWeightVector();
        FloatVector vec = new FloatVector(weightValues.size());
        for (int j = 0; j < weightValues.size(); j++) {
            vec.values[j] = weightValues.get(j).floatValue();
        }
        inner.put(name, vec);
    }
}
Also used : LabelDictionaryEntry(com.airbnb.aerosolve.core.LabelDictionaryEntry) FloatVector(com.airbnb.aerosolve.core.util.FloatVector) java.util(java.util) ModelRecord(com.airbnb.aerosolve.core.ModelRecord)

Example 29 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

The class LowRankLinearModel, method save.

/**
 * Writes this low-rank linear model to {@code writer}.
 *
 * <p>The first line is an encoded header record with these fields:
 * <ul>
 *   <li>model type {@code "low_rank_linear"};</li>
 *   <li>the number of feature records;</li>
 *   <li>the label dictionary;</li>
 *   <li>each label's embedding, truncated to {@code embeddingDimension} values.</li>
 * </ul>
 * It is followed by one encoded record per (feature family, feature name) holding that
 * feature's weight vector. The writer is flushed at the end.
 *
 * @param writer destination for the encoded records, one per line
 * @throws IOException if writing fails
 */
public void save(BufferedWriter writer) throws IOException {
    long featureCount = 0;
    for (Map<String, FloatVector> featuresOfFamily : featureWeightVector.values()) {
        featureCount += featuresOfFamily.size();
    }

    Map<String, java.util.List<Double>> embeddingsByLabel = new HashMap<>();
    for (Map.Entry<String, FloatVector> label : labelWeightVector.entrySet()) {
        float[] embedding = label.getValue().getValues();
        java.util.List<Double> embeddingAsDoubles = new ArrayList<>();
        for (int d = 0; d < embeddingDimension; d++) {
            embeddingAsDoubles.add((double) embedding[d]);
        }
        embeddingsByLabel.put(label.getKey(), embeddingAsDoubles);
    }

    ModelHeader header = new ModelHeader();
    header.setModelType("low_rank_linear");
    header.setNumRecords(featureCount);
    header.setLabelDictionary(labelDictionary);
    header.setLabelEmbedding(embeddingsByLabel);

    ModelRecord headerRecord = new ModelRecord();
    headerRecord.setModelHeader(header);
    writer.write(Util.encode(headerRecord));
    writer.newLine();

    for (Map.Entry<String, Map<String, FloatVector>> family : featureWeightVector.entrySet()) {
        String familyName = family.getKey();
        for (Map.Entry<String, FloatVector> feature : family.getValue().entrySet()) {
            java.util.List<Double> weightList = new ArrayList<>();
            for (float w : feature.getValue().values) {
                weightList.add((double) w);
            }
            ModelRecord featureRecord = new ModelRecord();
            featureRecord.setFeatureFamily(familyName);
            featureRecord.setFeatureName(feature.getKey());
            featureRecord.setWeightVector(weightList);
            writer.write(Util.encode(featureRecord));
            writer.newLine();
        }
    }
    writer.flush();
}
Also used : FloatVector(com.airbnb.aerosolve.core.util.FloatVector) ModelHeader(com.airbnb.aerosolve.core.ModelHeader) ModelRecord(com.airbnb.aerosolve.core.ModelRecord)

Example 30 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

The class MlpModel, method loadInternal.

/**
 * Restores a multilayer perceptron from its header and the records that follow it.
 *
 * <p>After copying the hidden-layer count and per-layer node counts from {@code header},
 * three sections are read from {@code reader}, one encoded record per line, in this order:
 * <ol>
 *   <li>{@code header.getNumRecords()} input-layer records, each holding the weight vector
 *       of one (feature family, feature name) pair;</li>
 *   <li>{@code numHiddenLayers + 1} records, each holding one layer's bias vector and
 *       activation function form;</li>
 *   <li>for each hidden layer {@code i}, {@code layerNodeNumber.get(i)} records, one per
 *       node, each holding that node's weight vector.</li>
 * </ol>
 *
 * @param header the already-decoded model header
 * @param reader source of the model records, one encoded record per line
 * @throws IOException if reading fails or {@code reader} ends before all sections are read
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
    numHiddenLayers = header.getNumHiddenLayers();
    // Appends rather than replaces; assumes layerNodeNumber starts empty on a freshly
    // constructed model — TODO confirm.
    layerNodeNumber.addAll(header.getNumberHiddenNodes());

    // load input layer weights
    // long counter to match the long record count (an int counter would wrap for huge counts).
    long rows = header.getNumRecords();
    for (long i = 0; i < rows; i++) {
        String line = reader.readLine();
        if (line == null) {
            throw new IOException("Unexpected end of model: expected " + rows
                + " input layer records but read only " + i);
        }
        ModelRecord record = Util.decodeModel(line);
        String family = record.getFeatureFamily();
        String name = record.getFeatureName();
        Map<String, FloatVector> inner = inputLayerWeights.get(family);
        if (inner == null) {
            inner = new HashMap<>();
            inputLayerWeights.put(family, inner);
        }
        List<Double> weightValues = record.getWeightVector();
        FloatVector vec = new FloatVector(weightValues.size());
        for (int j = 0; j < weightValues.size(); j++) {
            vec.values[j] = weightValues.get(j).floatValue();
        }
        inner.put(name, vec);
    }

    // load bias and activation function: one record per hidden layer plus the output layer
    for (int i = 0; i < numHiddenLayers + 1; i++) {
        String line = reader.readLine();
        if (line == null) {
            throw new IOException("Unexpected end of model: missing bias record for layer " + i);
        }
        ModelRecord record = Util.decodeModel(line);
        List<Double> arrayList = record.getWeightVector();
        FloatVector layerBias = new FloatVector(arrayList.size());
        for (int j = 0; j < arrayList.size(); j++) {
            layerBias.set(j, arrayList.get(j).floatValue());
        }
        bias.put(i, layerBias);
        activationFunction.add(record.getFunctionForm());
    }

    // load the hiddenLayerWeights, one record per (layer + node)
    for (int i = 0; i < numHiddenLayers; i++) {
        ArrayList<FloatVector> weights = new ArrayList<>();
        for (int j = 0; j < layerNodeNumber.get(i); j++) {
            String line = reader.readLine();
            if (line == null) {
                throw new IOException("Unexpected end of model: missing hidden weights for layer "
                    + i + ", node " + j);
            }
            ModelRecord record = Util.decodeModel(line);
            List<Double> arrayList = record.getWeightVector();
            FloatVector w = new FloatVector(arrayList.size());
            for (int k = 0; k < arrayList.size(); k++) {
                w.set(k, arrayList.get(k).floatValue());
            }
            weights.add(w);
        }
        hiddenLayerWeights.put(i, weights);
    }
}
Also used : FloatVector(com.airbnb.aerosolve.core.util.FloatVector) ModelRecord(com.airbnb.aerosolve.core.ModelRecord)

Aggregations

ModelRecord (com.airbnb.aerosolve.core.ModelRecord)40 ModelHeader (com.airbnb.aerosolve.core.ModelHeader)14 Test (org.junit.Test)11 ArrayList (java.util.ArrayList)10 FloatVector (com.airbnb.aerosolve.core.util.FloatVector)8 Map (java.util.Map)5 FeatureVector (com.airbnb.aerosolve.core.FeatureVector)4 BufferedReader (java.io.BufferedReader)3 BufferedWriter (java.io.BufferedWriter)3 CharArrayWriter (java.io.CharArrayWriter)3 IOException (java.io.IOException)3 StringReader (java.io.StringReader)3 LabelDictionaryEntry (com.airbnb.aerosolve.core.LabelDictionaryEntry)2 AbstractFunction (com.airbnb.aerosolve.core.function.AbstractFunction)2 Function (com.airbnb.aerosolve.core.function.Function)2 NDTreeModelTest (com.airbnb.aerosolve.core.models.NDTreeModelTest)2 SupportVector (com.airbnb.aerosolve.core.util.SupportVector)2 HashMap (java.util.HashMap)2 DebugScoreRecord (com.airbnb.aerosolve.core.DebugScoreRecord)1 MulticlassScoringResult (com.airbnb.aerosolve.core.MulticlassScoringResult)1