Search in sources:

Example 11 with ModelHeader

Use of com.airbnb.aerosolve.core.ModelHeader in project aerosolve by airbnb.

From class ModelFactory, method createFromReader.

/**
 * Reads a serialized model from {@code reader}.
 *
 * <p>The first line is decoded with {@code Util.decodeModel} into a {@link ModelRecord} that must
 * carry a {@link ModelHeader}. The header's model type is passed to {@code createByName}, and the
 * resulting model is handed the header and the same reader via {@code loadInternal} so it can
 * read the remaining lines.
 *
 * @param reader source positioned at the encoded header line
 * @return the loaded model, or {@code Optional.absent()} when the reader is empty, the header
 *     line cannot be decoded, the decoded record has no header, or {@code createByName} returns
 *     {@code null} for the header's model type
 * @throws IOException if reading from {@code reader} fails
 */
public static Optional<AbstractModel> createFromReader(BufferedReader reader) throws IOException {
    String headerLine = reader.readLine();
    if (headerLine == null) {
        // End of stream before any header: do not hand null to the decoder.
        log.error("Could not decode header: reader is empty");
        return Optional.absent();
    }
    ModelRecord record = Util.decodeModel(headerLine);
    if (record == null) {
        log.error("Could not decode header " + headerLine);
        return Optional.absent();
    }
    ModelHeader header = record.getModelHeader();
    if (header == null) {
        log.error("First record has no model header: " + headerLine);
        return Optional.absent();
    }
    AbstractModel result = createByName(header.getModelType());
    if (result == null) {
        log.error("Unknown model type " + header.getModelType());
        return Optional.absent();
    }
    result.loadInternal(header, reader);
    return Optional.of(result);
}
Also used : ModelHeader(com.airbnb.aerosolve.core.ModelHeader) ModelRecord(com.airbnb.aerosolve.core.ModelRecord)

Example 12 with ModelHeader

Use of com.airbnb.aerosolve.core.ModelHeader in project aerosolve by airbnb.

From class SplineModel, method save.

/**
 * Writes this spline model to {@code writer}, one encoded {@link ModelRecord} per line.
 *
 * <p>The first line is a header record (model type {@code "spline"}, {@code numBins}, slope,
 * offset, and the number of feature records that follow). After it comes one record per
 * (feature family, feature name) pair, holding the spline weights and the spline's min/max values.
 * The writer is flushed but not closed; the caller owns it.
 *
 * @param writer destination for the encoded records
 * @throws IOException if writing fails
 */
@Override
public void save(BufferedWriter writer) throws IOException {
    ModelHeader header = new ModelHeader();
    header.setModelType("spline");
    header.setNumHidden(numBins);
    header.setSlope(slope);
    header.setOffset(offset);
    // One record is written per inner-map entry, so the total is just the sum of inner sizes;
    // there is no need to iterate the entries themselves.
    long count = 0;
    for (Map<String, WeightSpline> features : weightSpline.values()) {
        count += features.size();
    }
    header.setNumRecords(count);
    ModelRecord headerRec = new ModelRecord();
    headerRec.setModelHeader(header);
    writer.write(Util.encode(headerRec));
    writer.newLine();
    for (Map.Entry<String, Map<String, WeightSpline>> familyMap : weightSpline.entrySet()) {
        String family = familyMap.getKey();
        for (Map.Entry<String, WeightSpline> feature : familyMap.getValue().entrySet()) {
            WeightSpline ws = feature.getValue();
            ModelRecord record = new ModelRecord();
            record.setFeatureFamily(family);
            record.setFeatureName(feature.getKey());
            // Widen each stored weight to double for the serialized weight vector.
            ArrayList<Double> weights = new ArrayList<Double>(ws.splineWeights.length);
            for (int i = 0; i < ws.splineWeights.length; i++) {
                weights.add((double) ws.splineWeights[i]);
            }
            record.setWeightVector(weights);
            record.setMinVal(ws.spline.getMinVal());
            record.setMaxVal(ws.spline.getMaxVal());
            writer.write(Util.encode(record));
            writer.newLine();
        }
    }
    writer.flush();
}
Also used : ModelHeader(com.airbnb.aerosolve.core.ModelHeader) ModelRecord(com.airbnb.aerosolve.core.ModelRecord)

Example 13 with ModelHeader

Use of com.airbnb.aerosolve.core.ModelHeader in project aerosolve by airbnb.

From class MlpModelTest, method testLoad.

/**
 * Round-trips a hand-built "multilayer_perceptron" model through {@code Util.encode} and
 * {@code ModelFactory.createFromReader}, then checks that scoring {@code makeFeatureVector()}
 * yields 6.0.
 *
 * <p>I/O failures propagate as {@link IOException} so the test fails with the real cause and
 * stack trace rather than a generic assertion message.
 */
@Test
public void testLoad() throws IOException {
    // create header record
    ModelHeader header = new ModelHeader();
    header.setModelType("multilayer_perceptron");
    header.setNumHiddenLayers(1);
    ArrayList<Integer> nodeNum = new ArrayList<>(2);
    nodeNum.add(3);
    nodeNum.add(1);
    header.setNumberHiddenNodes(nodeNum);
    header.setNumRecords(2);
    ModelRecord record1 = new ModelRecord();
    record1.setModelHeader(header);
    // create records for input layer
    ModelRecord record2 = new ModelRecord();
    record2.setFeatureFamily("in");
    record2.setFeatureName("a");
    ArrayList<Double> in1 = new ArrayList<>();
    in1.add(0.0);
    in1.add(1.0);
    in1.add(1.0);
    record2.setWeightVector(in1);
    ModelRecord record3 = new ModelRecord();
    record3.setFeatureFamily("in");
    record3.setFeatureName("b");
    ArrayList<Double> in2 = new ArrayList<>();
    in2.add(1.0);
    in2.add(1.0);
    in2.add(0.0);
    record3.setWeightVector(in2);
    // create records for bias
    ModelRecord record4 = new ModelRecord();
    ArrayList<Double> b1 = new ArrayList<>();
    b1.add(0.0);
    b1.add(0.0);
    b1.add(0.0);
    record4.setWeightVector(b1);
    record4.setFunctionForm(FunctionForm.RELU);
    ModelRecord record5 = new ModelRecord();
    ArrayList<Double> b2 = new ArrayList<>();
    b2.add(0.0);
    record5.setWeightVector(b2);
    record5.setFunctionForm(FunctionForm.RELU);
    // create records for hidden layer
    ModelRecord record6 = new ModelRecord();
    ArrayList<Double> h1 = new ArrayList<>();
    h1.add(0.5);
    record6.setWeightVector(h1);
    ModelRecord record7 = new ModelRecord();
    ArrayList<Double> h2 = new ArrayList<>();
    h2.add(1.0);
    record7.setWeightVector(h2);
    ModelRecord record8 = new ModelRecord();
    ArrayList<Double> h3 = new ArrayList<>();
    h3.add(2.0);
    record8.setWeightVector(h3);
    // Order matters: the header must be the first line read by createFromReader.
    ModelRecord[] records = {
        record1, record2, record3, record4, record5, record6, record7, record8
    };

    // Closing the BufferedWriter (via try-with-resources) flushes it into charWriter
    // before charWriter is read below.
    CharArrayWriter charWriter = new CharArrayWriter();
    try (BufferedWriter writer = new BufferedWriter(charWriter)) {
        for (ModelRecord record : records) {
            writer.write(Util.encode(record) + "\n");
        }
    }
    String serialized = charWriter.toString();
    assertTrue(serialized.length() > 0);

    BufferedReader reader = new BufferedReader(new StringReader(serialized));
    FeatureVector fv = makeFeatureVector();
    Optional<AbstractModel> model = ModelFactory.createFromReader(reader);
    assertTrue(model.isPresent());
    float s = model.get().scoreItem(fv);
    // JUnit's signature is assertEquals(expected, actual, delta).
    assertEquals(6.0f, s, 1e-10f);
}
Also used : FeatureVector(com.airbnb.aerosolve.core.FeatureVector) ArrayList(java.util.ArrayList) IOException(java.io.IOException) CharArrayWriter(java.io.CharArrayWriter) BufferedWriter(java.io.BufferedWriter) ModelHeader(com.airbnb.aerosolve.core.ModelHeader) StringReader(java.io.StringReader) BufferedReader(java.io.BufferedReader) ModelRecord(com.airbnb.aerosolve.core.ModelRecord) Test(org.junit.Test)

Example 14 with ModelHeader

Use of com.airbnb.aerosolve.core.ModelHeader in project aerosolve by airbnb.

From class AdditiveModelTest, method testLoad.

/**
 * Round-trips a hand-built "additive" model (two spline and two linear feature records)
 * through {@code Util.encode} and {@code ModelFactory.createFromReader}. It then checks the score
 * of {@code makeFeatureVector(2.0f, 7.0f)} against 8 + 10 + 15.
 *
 * <p>I/O failures propagate as {@link IOException} so the test fails with the real cause and
 * stack trace rather than a generic assertion message.
 */
@Test
public void testLoad() throws IOException {
    ModelHeader header = new ModelHeader();
    header.setModelType("additive");
    header.setNumRecords(4);
    // Weight vectors are shared between records; they are only read during encoding.
    ArrayList<Double> ws = new ArrayList<Double>();
    ws.add(5.0);
    ws.add(10.0);
    ws.add(-20.0);
    ArrayList<Double> wl = new ArrayList<Double>();
    wl.add(1.0);
    wl.add(2.0);
    ModelRecord record1 = new ModelRecord();
    record1.setModelHeader(header);
    ModelRecord record2 = new ModelRecord();
    record2.setFunctionForm(FunctionForm.Spline);
    record2.setFeatureFamily("spline_float");
    record2.setFeatureName("aaa");
    record2.setWeightVector(ws);
    record2.setMinVal(1.0);
    record2.setMaxVal(3.0);
    ModelRecord record3 = new ModelRecord();
    record3.setFunctionForm(FunctionForm.Spline);
    record3.setFeatureFamily("spline_string");
    record3.setFeatureName("bbb");
    record3.setWeightVector(ws);
    record3.setMinVal(1.0);
    record3.setMaxVal(2.0);
    ModelRecord record4 = new ModelRecord();
    record4.setFunctionForm(FunctionForm.Linear);
    record4.setFeatureFamily("linear_float");
    record4.setFeatureName("ccc");
    record4.setWeightVector(wl);
    ModelRecord record5 = new ModelRecord();
    record5.setFunctionForm(FunctionForm.Linear);
    record5.setFeatureFamily("linear_string");
    record5.setFeatureName("ddd");
    record5.setWeightVector(wl);
    // Order matters: the header must be the first line read by createFromReader.
    ModelRecord[] records = {record1, record2, record3, record4, record5};

    // Closing the BufferedWriter (via try-with-resources) flushes it into charWriter
    // before charWriter is read below.
    CharArrayWriter charWriter = new CharArrayWriter();
    try (BufferedWriter writer = new BufferedWriter(charWriter)) {
        for (ModelRecord record : records) {
            writer.write(Util.encode(record) + "\n");
        }
    }
    String serialized = charWriter.toString();
    assertTrue(serialized.length() > 0);

    BufferedReader reader = new BufferedReader(new StringReader(serialized));
    FeatureVector featureVector = makeFeatureVector(2.0f, 7.0f);
    Optional<AbstractModel> model = ModelFactory.createFromReader(reader);
    assertTrue(model.isPresent());
    float score = model.get().scoreItem(featureVector);
    assertEquals(8.0f + 10.0f + 15.0f, score, 0.001f);
}
Also used : FeatureVector(com.airbnb.aerosolve.core.FeatureVector) ArrayList(java.util.ArrayList) IOException(java.io.IOException) CharArrayWriter(java.io.CharArrayWriter) BufferedWriter(java.io.BufferedWriter) ModelHeader(com.airbnb.aerosolve.core.ModelHeader) StringReader(java.io.StringReader) BufferedReader(java.io.BufferedReader) ModelRecord(com.airbnb.aerosolve.core.ModelRecord) Test(org.junit.Test)

Aggregations

ModelHeader (com.airbnb.aerosolve.core.ModelHeader)14 ModelRecord (com.airbnb.aerosolve.core.ModelRecord)14 FeatureVector (com.airbnb.aerosolve.core.FeatureVector)4 FloatVector (com.airbnb.aerosolve.core.util.FloatVector)4 Test (org.junit.Test)4 BufferedReader (java.io.BufferedReader)3 BufferedWriter (java.io.BufferedWriter)3 CharArrayWriter (java.io.CharArrayWriter)3 IOException (java.io.IOException)3 StringReader (java.io.StringReader)3 ArrayList (java.util.ArrayList)3 HashMap (java.util.HashMap)2 Map (java.util.Map)2 MulticlassScoringResult (com.airbnb.aerosolve.core.MulticlassScoringResult)1 AbstractFunction (com.airbnb.aerosolve.core.function.AbstractFunction)1 Function (com.airbnb.aerosolve.core.function.Function)1 SupportVector (com.airbnb.aerosolve.core.util.SupportVector)1 AbstractMap (java.util.AbstractMap)1