Search in sources:

Example 21 with FloatVector

Use of com.airbnb.aerosolve.core.util.FloatVector in project aerosolve by airbnb.

Class MlpModel, method loadInternal.

/**
 * Loads the network weights from {@code reader}, which must be positioned right after the header.
 *
 * <p>Records are consumed in this order:
 * <ol>
 *   <li>{@code header.getNumRecords()} input-layer weight records, keyed by feature family/name;</li>
 *   <li>{@code numHiddenLayers + 1} records, each carrying one layer's bias vector and activation function;</li>
 *   <li>for each hidden layer {@code i}, {@code layerNodeNumber.get(i)} records holding one node's weights.</li>
 * </ol>
 *
 * @param header model header supplying layer counts and the number of input-layer records
 * @param reader source of the encoded model records, one per line
 * @throws IOException if reading fails or the stream ends before all expected records are read
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
    numHiddenLayers = header.getNumHiddenLayers();
    layerNodeNumber.addAll(header.getNumberHiddenNodes());
    // load input layer weights
    long rows = header.getNumRecords();
    for (long i = 0; i < rows; i++) {
        ModelRecord record = readRecordOrFail(reader, "input layer weight");
        String family = record.getFeatureFamily();
        String name = record.getFeatureName();
        Map<String, FloatVector> inner = inputLayerWeights.get(family);
        if (inner == null) {
            inner = new HashMap<>();
            inputLayerWeights.put(family, inner);
        }
        inner.put(name, weightsToFloatVector(record.getWeightVector()));
    }
    // load bias and activation function: one record per hidden layer plus one for the output layer
    for (int i = 0; i < numHiddenLayers + 1; i++) {
        ModelRecord record = readRecordOrFail(reader, "bias/activation");
        bias.put(i, weightsToFloatVector(record.getWeightVector()));
        activationFunction.add(record.getFunctionForm());
    }
    // load the hiddenLayerWeights, one record per (layer + node)
    for (int i = 0; i < numHiddenLayers; i++) {
        ArrayList<FloatVector> weights = new ArrayList<>();
        for (int j = 0; j < layerNodeNumber.get(i); j++) {
            ModelRecord record = readRecordOrFail(reader, "hidden layer weight");
            weights.add(weightsToFloatVector(record.getWeightVector()));
        }
        hiddenLayerWeights.put(i, weights);
    }
}

/** Reads and decodes the next model record, failing fast if the stream ends prematurely. */
private static ModelRecord readRecordOrFail(BufferedReader reader, String section) throws IOException {
    String line = reader.readLine();
    if (line == null) {
        // readLine() returns null at end of stream; decoding null would only fail later and obscurely.
        throw new IOException("Unexpected end of model while reading " + section + " records");
    }
    return Util.decodeModel(line);
}

/** Copies a record's weight list into a new FloatVector of the same length, narrowing to float. */
private static FloatVector weightsToFloatVector(List<Double> weights) {
    FloatVector vec = new FloatVector(weights.size());
    for (int k = 0; k < weights.size(); k++) {
        vec.set(k, weights.get(k).floatValue());
    }
    return vec;
}
Also used : FloatVector(com.airbnb.aerosolve.core.util.FloatVector) ModelRecord(com.airbnb.aerosolve.core.ModelRecord)

Example 22 with FloatVector

Use of com.airbnb.aerosolve.core.util.FloatVector in project aerosolve by airbnb.

Class MaxoutModel, method getResponse.

/**
 * Computes the hidden-layer response for a set of flat features.
 *
 * <p>Each (family, feature) pair that has a weight vector in the model adds
 * {@code value * scale * weights} to the running sum. Unknown families and features are skipped.
 * The bias weights are added once at the end.
 *
 * @param flatFeatures feature values grouped by family name, then by feature name
 * @return a vector of length {@code numHidden} holding the accumulated response
 */
public FloatVector getResponse(Map<String, Map<String, Double>> flatFeatures) {
    FloatVector response = new FloatVector(numHidden);
    for (Map.Entry<String, Map<String, Double>> familyEntry : flatFeatures.entrySet()) {
        Map<String, WeightVector> familyWeights = weightVector.get(familyEntry.getKey());
        if (familyWeights == null) {
            continue; // family not present in the model
        }
        for (Map.Entry<String, Double> featureEntry : familyEntry.getValue().entrySet()) {
            WeightVector hiddenWeights = familyWeights.get(featureEntry.getKey());
            if (hiddenWeights == null) {
                continue; // feature not present in the model
            }
            response.multiplyAdd(featureEntry.getValue().floatValue() * hiddenWeights.scale,
                                 hiddenWeights.weights);
        }
    }
    response.add(bias.weights);
    return response;
}
Also used : FloatVector(com.airbnb.aerosolve.core.util.FloatVector)

Example 23 with FloatVector

Use of com.airbnb.aerosolve.core.util.FloatVector in project aerosolve by airbnb.

Class MaxoutModel, method scoreFlatFeatures.

/**
 * Scores flat features as the spread of the hidden-layer response.
 *
 * @param flatFeatures feature values grouped by family name, then by feature name
 * @return the largest response component minus the smallest one
 */
public float scoreFlatFeatures(Map<String, Map<String, Double>> flatFeatures) {
    FloatVector.MinMaxResult extremes = getResponse(flatFeatures).getMinMaxResult();
    return extremes.maxValue - extremes.minValue;
}
Also used : FloatVector(com.airbnb.aerosolve.core.util.FloatVector)

Example 24 with FloatVector

Use of com.airbnb.aerosolve.core.util.FloatVector in project aerosolve by airbnb.

Class MlpModelTest, method makeMlpModel.

/**
 * Builds a small fixed MLP fixture for tests.
 *
 * <p>The network has one hidden layer with 3 nodes. The node-count list is {@code [3, 1]}; the
 * trailing 1 is presumably the output layer. Both layers use {@code func} as their activation,
 * and biases are left unset (assumed zero).
 *
 * @param func activation function applied to every layer
 * @return a model with hand-set input-layer and hidden-layer weights
 */
public MlpModel makeMlpModel(FunctionForm func) {
    // NOTE(review): raw collections kept because the MlpModel constructor/setter signatures are not visible here.
    ArrayList nodeNum = new ArrayList(2);
    nodeNum.add(3);
    nodeNum.add(1);
    ArrayList activations = new ArrayList();
    activations.add(func);
    activations.add(func);
    MlpModel model = new MlpModel(activations, nodeNum);

    // Input layer: family "in" with features "a" and "b", each holding 3 weights (one per hidden node).
    HashMap inner = new HashMap<>();
    inner.put("a", floatVectorOf(0.0f, 1.0f, 1.0f));
    inner.put("b", floatVectorOf(1.0f, 1.0f, 0.0f));
    HashMap inputLayer = new HashMap<>();
    inputLayer.put("in", inner);
    model.setInputLayerWeights(inputLayer);

    // Hidden layer 0: one single-element weight vector per hidden node.
    ArrayList hidden = new ArrayList(3);
    hidden.add(floatVectorOf(0.5f));
    hidden.add(floatVectorOf(1.0f));
    hidden.add(floatVectorOf(2.0f));
    HashMap hiddenLayer = new HashMap<>();
    hiddenLayer.put(0, hidden);
    model.setHiddenLayerWeights(hiddenLayer);
    return model;
}

/** Builds a FloatVector of {@code values.length} entries holding {@code values} in order. */
private static FloatVector floatVectorOf(float... values) {
    FloatVector vec = new FloatVector(values.length);
    for (int i = 0; i < values.length; i++) {
        vec.set(i, values[i]);
    }
    return vec;
}
Also used : HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) FloatVector(com.airbnb.aerosolve.core.util.FloatVector)

Example 25 with FloatVector

Use of com.airbnb.aerosolve.core.util.FloatVector in project aerosolve by airbnb.

Class LowRankLinearModelTest, method makeLowRankLinearModel.

/**
 * Builds a naive low-rank linear model fixture with three classes: 'animal', 'color' and 'fruit'.
 *
 * <p>The embedding dimension equals the number of labels (3). Every word in a class gets that
 * class's one-hot embedding, so the feature-to-embedding matrix W acts as an identity over classes.
 *
 * @return a model with feature weights, label dictionary and label weights set, and its
 *     label-to-index mapping built
 */
LowRankLinearModel makeLowRankLinearModel() {
    LowRankLinearModel model = new LowRankLinearModel();
    model.setEmbeddingDimension(3);
    model.setLabelDictionary(makeLabelDictionary());
    // construct featureWeightVector
    String[] animalWords = { "cat", "dog", "horse", "fish" };
    String[] colorWords = { "red", "black", "blue", "white", "yellow" };
    String[] fruitWords = { "apple", "kiwi", "pear", "peach" };
    float[] animalFeature = { 1.0f, 0.0f, 0.0f };
    float[] colorFeature = { 0.0f, 1.0f, 0.0f };
    float[] fruitFeature = { 0.0f, 0.0f, 1.0f };
    Map<String, Map<String, FloatVector>> featureWeights = new HashMap<>();
    featureWeights.put("a", featuresWithEmbedding(animalWords, animalFeature));
    featureWeights.put("c", featuresWithEmbedding(colorWords, colorFeature));
    featureWeights.put("f", featuresWithEmbedding(fruitWords, fruitFeature));
    model.setFeatureWeightVector(featureWeights);
    // set labelWeightVector
    model.setLabelWeightVector(makeLabelWeightVector());
    model.buildLabelToIndex();
    return model;
}

/**
 * Maps every word to its own FloatVector initialised from a private copy of {@code embedding}.
 *
 * <p>The copy keeps entries independent. Without it, if {@code FloatVector(float[])} stores the
 * array without copying (not visible here), every word of a class would share one buffer.
 */
private static Map<String, FloatVector> featuresWithEmbedding(String[] words, float[] embedding) {
    Map<String, FloatVector> features = new HashMap<>();
    for (String word : words) {
        features.put(word, new FloatVector(embedding.clone()));
    }
    return features;
}
Also used : HashMap(java.util.HashMap) FloatVector(com.airbnb.aerosolve.core.util.FloatVector) Map(java.util.Map)

Aggregations

FloatVector (com.airbnb.aerosolve.core.util.FloatVector)26 ModelRecord (com.airbnb.aerosolve.core.ModelRecord)8 HashMap (java.util.HashMap)5 ModelHeader (com.airbnb.aerosolve.core.ModelHeader)4 MulticlassScoringResult (com.airbnb.aerosolve.core.MulticlassScoringResult)3 Map (java.util.Map)3 LabelDictionaryEntry (com.airbnb.aerosolve.core.LabelDictionaryEntry)2 SupportVector (com.airbnb.aerosolve.core.util.SupportVector)2 ArrayList (java.util.ArrayList)2 DebugScoreRecord (com.airbnb.aerosolve.core.DebugScoreRecord)1 FeatureVector (com.airbnb.aerosolve.core.FeatureVector)1 BufferedReader (java.io.BufferedReader)1 BufferedWriter (java.io.BufferedWriter)1 CharArrayWriter (java.io.CharArrayWriter)1 IOException (java.io.IOException)1 StringReader (java.io.StringReader)1 java.util (java.util)1 AbstractMap (java.util.AbstractMap)1 Test (org.junit.Test)1