Search in sources :

Example 11 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

The method save of the class MaxoutModel.

/**
 * Writes this maxout model to {@code writer}, one encoded record per line.
 *
 * <p>The first line is a header record carrying the model type ({@code "maxout"}), the hidden
 * unit count and the number of feature records that follow. Every following line is the record
 * of one (family, feature) pair: family name, feature name, weight vector (widened to
 * {@code double}) and scale. The writer is flushed at the end but not closed.
 *
 * @param writer destination for the encoded records
 * @throws IOException if writing to or flushing {@code writer} fails
 */
public void save(BufferedWriter writer) throws IOException {
    ModelHeader header = new ModelHeader();
    header.setModelType("maxout");
    header.setNumHidden(numHidden);
    // One record is written per (family, feature) pair below, so the record count is simply
    // the sum of the inner map sizes; no need to walk every entry just to count it.
    long count = 0;
    for (Map<String, WeightVector> features : weightVector.values()) {
        count += features.size();
    }
    header.setNumRecords(count);
    ModelRecord headerRec = new ModelRecord();
    headerRec.setModelHeader(header);
    writer.write(Util.encode(headerRec));
    writer.newLine();
    for (Map.Entry<String, Map<String, WeightVector>> familyMap : weightVector.entrySet()) {
        for (Map.Entry<String, WeightVector> feature : familyMap.getValue().entrySet()) {
            WeightVector vec = feature.getValue();
            ModelRecord record = new ModelRecord();
            record.setFeatureFamily(familyMap.getKey());
            record.setFeatureName(feature.getKey());
            // The final size is known up front, so presize the list to avoid regrowth.
            int numWeights = vec.weights.values.length;
            ArrayList<Double> arrayList = new ArrayList<>(numWeights);
            for (int i = 0; i < numWeights; i++) {
                arrayList.add((double) vec.weights.values[i]);
            }
            record.setWeightVector(arrayList);
            record.setScale(vec.scale);
            writer.write(Util.encode(record));
            writer.newLine();
        }
    }
    writer.flush();
}
Also used : ModelHeader(com.airbnb.aerosolve.core.ModelHeader) ModelRecord(com.airbnb.aerosolve.core.ModelRecord)

Example 12 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

The method save of the class MlpModel.

/**
 * Writes this multilayer perceptron to {@code writer}, one encoded record per line.
 *
 * <p>Layout of the output, in order:
 * <ol>
 *   <li>a header record (model type {@code "multilayer_perceptron"}, number of hidden layers,
 *       node count of each hidden layer plus the output layer, and the number of input-layer
 *       records);</li>
 *   <li>one record per input feature holding its family, name and weight vector;</li>
 *   <li>one record per layer after the input layer holding its bias vector and activation
 *       function;</li>
 *   <li>one record per node of each of the first {@code numHiddenLayers} entries of
 *       {@code hiddenLayerWeights}, holding that node's weight vector.</li>
 * </ol>
 * The writer is flushed at the end but not closed.
 *
 * @param writer destination for the encoded records
 * @throws IOException if writing to or flushing {@code writer} fails
 */
public void save(BufferedWriter writer) throws IOException {
    ModelHeader modelHeader = new ModelHeader();
    modelHeader.setModelType("multilayer_perceptron");
    modelHeader.setNumHiddenLayers(numHiddenLayers);

    // Node counts for every hidden layer and, as the last entry, the output layer.
    ArrayList<Integer> nodesPerLayer = new ArrayList<>();
    for (int layer = 0; layer < numHiddenLayers + 1; layer++) {
        nodesPerLayer.add(layerNodeNumber.get(layer));
    }
    modelHeader.setNumberHiddenNodes(nodesPerLayer);

    // The header's record count covers the input layer weights only.
    long inputRecordCount = 0;
    for (Map<String, FloatVector> featuresOfFamily : inputLayerWeights.values()) {
        inputRecordCount += featuresOfFamily.size();
    }
    modelHeader.setNumRecords(inputRecordCount);

    ModelRecord headerRecord = new ModelRecord();
    headerRecord.setModelHeader(modelHeader);
    writer.write(Util.encode(headerRecord));
    writer.newLine();

    // Input layer: one record per (family, feature).
    for (Map.Entry<String, Map<String, FloatVector>> familyEntry : inputLayerWeights.entrySet()) {
        String familyName = familyEntry.getKey();
        for (Map.Entry<String, FloatVector> featureEntry : familyEntry.getValue().entrySet()) {
            FloatVector featureWeights = featureEntry.getValue();
            ModelRecord featureRecord = new ModelRecord();
            featureRecord.setFeatureFamily(familyName);
            featureRecord.setFeatureName(featureEntry.getKey());
            ArrayList<Double> widened = new ArrayList<>();
            for (int idx = 0; idx < featureWeights.length(); idx++) {
                widened.add((double) featureWeights.values[idx]);
            }
            featureRecord.setWeightVector(widened);
            writer.write(Util.encode(featureRecord));
            writer.newLine();
        }
    }

    // Biases: one record per layer after the input layer, tagged with its activation function.
    for (int layer = 0; layer < numHiddenLayers + 1; layer++) {
        FloatVector layerBias = bias.get(layer);
        int biasLength = layerBias.length();
        ArrayList<Double> widened = new ArrayList<>();
        for (int idx = 0; idx < biasLength; idx++) {
            widened.add((double) layerBias.get(idx));
        }
        ModelRecord biasRecord = new ModelRecord();
        biasRecord.setWeightVector(widened);
        biasRecord.setFunctionForm(activationFunction.get(layer));
        writer.write(Util.encode(biasRecord));
        writer.newLine();
    }

    // Hidden layer weights: one record per (layer, node).
    for (int layer = 0; layer < numHiddenLayers; layer++) {
        ArrayList<FloatVector> layerWeights = hiddenLayerWeights.get(layer);
        for (int node = 0; node < layerNodeNumber.get(layer); node++) {
            FloatVector nodeWeights = layerWeights.get(node);
            ArrayList<Double> widened = new ArrayList<>();
            for (int idx = 0; idx < nodeWeights.length(); idx++) {
                widened.add((double) nodeWeights.get(idx));
            }
            ModelRecord nodeRecord = new ModelRecord();
            nodeRecord.setWeightVector(widened);
            writer.write(Util.encode(nodeRecord));
            writer.newLine();
        }
    }
    writer.flush();
}
Also used : FloatVector(com.airbnb.aerosolve.core.util.FloatVector) ModelHeader(com.airbnb.aerosolve.core.ModelHeader) ModelRecord(com.airbnb.aerosolve.core.ModelRecord)

Example 13 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

The method loadInternal of the class SplineModel.

/**
 * Rebuilds the spline weights of this model from its serialized records.
 *
 * <p>Reads {@code numBins}, {@code slope} and {@code offset} from {@code header}, then reads
 * {@code header.getNumRecords()} lines from {@code reader}. Each line is decoded into a
 * {@link ModelRecord} whose min/max values and first {@code numBins} weights populate a
 * {@code WeightSpline} stored under the record's feature family and feature name. Any
 * previously loaded {@code weightSpline} map is replaced.
 *
 * @param header model header describing the number of records and the spline parameters
 * @param reader source positioned at the first record line
 * @throws IOException if reading fails, or if the input ends before all records announced by
 *     the header have been read
 */
@Override
protected void loadInternal(ModelHeader header, BufferedReader reader) throws IOException {
    long rows = header.getNumRecords();
    numBins = header.getNumHidden();
    slope = header.getSlope();
    offset = header.getOffset();
    weightSpline = new HashMap<>();
    for (long i = 0; i < rows; i++) {
        String line = reader.readLine();
        // readLine() returns null at end of stream; report a truncated model explicitly
        // instead of handing null to the decoder.
        if (line == null) {
            throw new IOException("Unexpected end of model data: header announced " + rows
                + " records but only " + i + " could be read");
        }
        ModelRecord record = Util.decodeModel(line);
        String family = record.getFeatureFamily();
        String name = record.getFeatureName();
        Map<String, WeightSpline> inner = weightSpline.get(family);
        if (inner == null) {
            inner = new HashMap<>();
            weightSpline.put(family, inner);
        }
        float minVal = (float) record.getMinVal();
        float maxVal = (float) record.getMaxVal();
        WeightSpline vec = new WeightSpline(minVal, maxVal, numBins);
        for (int j = 0; j < numBins; j++) {
            vec.splineWeights[j] = record.getWeightVector().get(j).floatValue();
        }
        inner.put(name, vec);
    }
}
Also used : ModelRecord(com.airbnb.aerosolve.core.ModelRecord)

Example 14 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

The method testLinearToModelRecord of the class LinearTest.

/**
 * Checks that converting the linear test function to a {@link ModelRecord} carries over the
 * given family and name, the first two weights, and the min/max values.
 */
@Test
public void testLinearToModelRecord() {
    Linear linearFunc = createLinearTestExample();
    ModelRecord record = linearFunc.toModelRecord("family", "name");
    // JUnit's contract is assertEquals(expected, actual); keeping that order makes the
    // failure message report the right value as "expected".
    assertEquals("family", record.getFeatureFamily());
    assertEquals("name", record.getFeatureName());
    List<Double> weightVector = record.getWeightVector();
    assertEquals(0.2f, weightVector.get(0).floatValue(), 0.01f);
    assertEquals(1.5f, weightVector.get(1).floatValue(), 0.01f);
    assertEquals(-6.0f, record.getMinVal(), 0.01f);
    assertEquals(5.0f, record.getMaxVal(), 0.01f);
}
Also used : ModelRecord(com.airbnb.aerosolve.core.ModelRecord) Test(org.junit.Test)

Example 15 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

The method modelRecord of the class MultiDimensionSplineTest.

/**
 * Round-trips a spline through {@link ModelRecord} and checks that both the original and the
 * spline rebuilt from the record evaluate to the same expected value at (3, 3).
 */
@Test
public void modelRecord() {
    final double expectedAtThreeThree = 0.40389338302461303;
    final double tolerance = 0.0001;

    MultiDimensionSpline original = getMultiDimensionSpline();
    // NOTE(review): set(...) presumably populates the spline's weights; its definition is not shown here.
    set(original);

    // Serialize to a record (empty family and name) and rebuild a second spline from it.
    ModelRecord serialized = original.toModelRecord("", "");
    MultiDimensionSpline restored = new MultiDimensionSpline(serialized);

    assertEquals(expectedAtThreeThree, original.evaluate(3.0f, 3.0f), tolerance);
    assertEquals(expectedAtThreeThree, restored.evaluate(3.0f, 3.0f), tolerance);
}
Also used : ModelRecord(com.airbnb.aerosolve.core.ModelRecord) NDTreeModelTest(com.airbnb.aerosolve.core.models.NDTreeModelTest) Test(org.junit.Test)

Aggregations

ModelRecord (com.airbnb.aerosolve.core.ModelRecord)40 ModelHeader (com.airbnb.aerosolve.core.ModelHeader)14 Test (org.junit.Test)11 ArrayList (java.util.ArrayList)10 FloatVector (com.airbnb.aerosolve.core.util.FloatVector)8 Map (java.util.Map)5 FeatureVector (com.airbnb.aerosolve.core.FeatureVector)4 BufferedReader (java.io.BufferedReader)3 BufferedWriter (java.io.BufferedWriter)3 CharArrayWriter (java.io.CharArrayWriter)3 IOException (java.io.IOException)3 StringReader (java.io.StringReader)3 LabelDictionaryEntry (com.airbnb.aerosolve.core.LabelDictionaryEntry)2 AbstractFunction (com.airbnb.aerosolve.core.function.AbstractFunction)2 Function (com.airbnb.aerosolve.core.function.Function)2 NDTreeModelTest (com.airbnb.aerosolve.core.models.NDTreeModelTest)2 SupportVector (com.airbnb.aerosolve.core.util.SupportVector)2 HashMap (java.util.HashMap)2 DebugScoreRecord (com.airbnb.aerosolve.core.DebugScoreRecord)1 MulticlassScoringResult (com.airbnb.aerosolve.core.MulticlassScoringResult)1