Search in sources:

Example 36 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

From the class MultiDimensionSplineTest, method modelRecord1D.

/**
 * Round-trips a one-dimensional MultiDimensionSpline through a ModelRecord and checks that both the
 * original spline and the one rebuilt from the record evaluate to 0.7 at x = 3.5.
 */
@Test
public void modelRecord1D() {
    MultiDimensionSpline original = getMultiDimensionSpline1D();
    set1D(original);
    // Serialize with an empty feature family and name, then rebuild a spline from that record.
    MultiDimensionSpline restored = new MultiDimensionSpline(original.toModelRecord("", ""));
    final double tolerance = 0.0001;
    assertEquals(0.7, original.evaluate(3.5f), tolerance);
    assertEquals(0.7, restored.evaluate(3.5f), tolerance);
}
Also used : ModelRecord(com.airbnb.aerosolve.core.ModelRecord) NDTreeModelTest(com.airbnb.aerosolve.core.models.NDTreeModelTest) Test(org.junit.Test)

Example 37 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

From the class Linear, method toModelRecord.

/**
 * Serializes this linear function into a ModelRecord.
 *
 * <p>The record is tagged with the given feature family and name, uses the Linear function form,
 * and carries minVal/maxVal plus the first two entries of {@code weights}, each cast to double.
 *
 * @param featureFamily feature family stored on the record
 * @param featureName feature name stored on the record
 * @return a newly created record describing this function
 */
@Override
public ModelRecord toModelRecord(String featureFamily, String featureName) {
    // Only weights[0] and weights[1] are serialized.
    ArrayList<Double> weightVector = new ArrayList<Double>();
    for (int i = 0; i < 2; i++) {
        weightVector.add((double) weights[i]);
    }

    ModelRecord result = new ModelRecord();
    result.setFunctionForm(FunctionForm.Linear);
    result.setFeatureFamily(featureFamily);
    result.setFeatureName(featureName);
    result.setMinVal(minVal);
    result.setMaxVal(maxVal);
    result.setWeightVector(weightVector);
    return result;
}
Also used : ArrayList(java.util.ArrayList) ModelRecord(com.airbnb.aerosolve.core.ModelRecord)

Example 38 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

From the class MlpModelTest, method testLoad.

/**
 * Builds a small multilayer-perceptron model and checks that it can be serialized and reloaded.
 *
 * <p>The model consists of a header, two input-layer records, two bias records and three
 * hidden-layer records. They are written as newline-delimited encoded ModelRecords, reloaded via
 * ModelFactory.createFromReader, and the test asserts that the reloaded model scores
 * makeFeatureVector() as 6.0.
 */
@Test
public void testLoad() {
    // create header record
    ModelHeader header = new ModelHeader();
    header.setModelType("multilayer_perceptron");
    header.setNumHiddenLayers(1);
    ArrayList<Integer> nodeNum = new ArrayList<>(2);
    nodeNum.add(3);
    nodeNum.add(1);
    header.setNumberHiddenNodes(nodeNum);
    header.setNumRecords(2);
    ModelRecord record1 = new ModelRecord();
    record1.setModelHeader(header);
    // create records for input layer
    ModelRecord record2 = new ModelRecord();
    record2.setFeatureFamily("in");
    record2.setFeatureName("a");
    ArrayList<Double> in1 = new ArrayList<>();
    in1.add(0.0);
    in1.add(1.0);
    in1.add(1.0);
    record2.setWeightVector(in1);
    ModelRecord record3 = new ModelRecord();
    record3.setFeatureFamily("in");
    record3.setFeatureName("b");
    ArrayList<Double> in2 = new ArrayList<>();
    in2.add(1.0);
    in2.add(1.0);
    in2.add(0.0);
    record3.setWeightVector(in2);
    // create records for bias
    ModelRecord record4 = new ModelRecord();
    ArrayList<Double> b1 = new ArrayList<>();
    b1.add(0.0);
    b1.add(0.0);
    b1.add(0.0);
    record4.setWeightVector(b1);
    record4.setFunctionForm(FunctionForm.RELU);
    ModelRecord record5 = new ModelRecord();
    ArrayList<Double> b2 = new ArrayList<>();
    b2.add(0.0);
    record5.setWeightVector(b2);
    record5.setFunctionForm(FunctionForm.RELU);
    // create records for hidden layer
    ModelRecord record6 = new ModelRecord();
    ArrayList<Double> h1 = new ArrayList<>();
    h1.add(0.5);
    record6.setWeightVector(h1);
    ModelRecord record7 = new ModelRecord();
    ArrayList<Double> h2 = new ArrayList<>();
    h2.add(1.0);
    record7.setWeightVector(h2);
    ModelRecord record8 = new ModelRecord();
    ArrayList<Double> h3 = new ArrayList<>();
    h3.add(2.0);
    record8.setWeightVector(h3);

    // Records are written in this exact order: header, input layer, bias, hidden layer.
    ModelRecord[] records = {
        record1, record2, record3, record4, record5, record6, record7, record8
    };
    CharArrayWriter charWriter = new CharArrayWriter();
    // try-with-resources closes (and flushes) the writer even if a write fails.
    try (BufferedWriter writer = new BufferedWriter(charWriter)) {
        for (ModelRecord record : records) {
            writer.write(Util.encode(record) + "\n");
        }
    } catch (IOException e) {
        // Fail the test but keep the underlying cause for diagnosis.
        throw new AssertionError("Could not write", e);
    }
    String serialized = charWriter.toString();
    assertTrue(serialized.length() > 0);
    FeatureVector fv = makeFeatureVector();
    try (BufferedReader reader = new BufferedReader(new StringReader(serialized))) {
        Optional<AbstractModel> model = ModelFactory.createFromReader(reader);
        assertTrue(model.isPresent());
        float s = model.get().scoreItem(fv);
        // JUnit's assertEquals takes (expected, actual, delta).
        assertEquals(6.0f, s, 1e-10f);
    } catch (IOException e) {
        throw new AssertionError("Could not read", e);
    }
}
Also used : FeatureVector(com.airbnb.aerosolve.core.FeatureVector) ArrayList(java.util.ArrayList) IOException(java.io.IOException) CharArrayWriter(java.io.CharArrayWriter) BufferedWriter(java.io.BufferedWriter) ModelHeader(com.airbnb.aerosolve.core.ModelHeader) StringReader(java.io.StringReader) BufferedReader(java.io.BufferedReader) ModelRecord(com.airbnb.aerosolve.core.ModelRecord) Test(org.junit.Test)

Example 39 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

From the class AdditiveModelTest, method testLoad.

/**
 * Builds a small additive model and checks that it can be serialized and reloaded.
 *
 * <p>The model has a header and four feature records: two Spline and two Linear. They are written
 * as newline-delimited encoded ModelRecords, reloaded via ModelFactory.createFromReader, and the
 * test asserts the score of makeFeatureVector(2.0f, 7.0f).
 */
@Test
public void testLoad() {
    ModelHeader header = new ModelHeader();
    header.setModelType("additive");
    header.setNumRecords(4);
    // Spline weights, shared by both spline records.
    ArrayList<Double> ws = new ArrayList<Double>();
    ws.add(5.0);
    ws.add(10.0);
    ws.add(-20.0);
    // Linear weights, shared by both linear records.
    ArrayList<Double> wl = new ArrayList<Double>();
    wl.add(1.0);
    wl.add(2.0);
    ModelRecord record1 = new ModelRecord();
    record1.setModelHeader(header);
    ModelRecord record2 = new ModelRecord();
    record2.setFunctionForm(FunctionForm.Spline);
    record2.setFeatureFamily("spline_float");
    record2.setFeatureName("aaa");
    record2.setWeightVector(ws);
    record2.setMinVal(1.0);
    record2.setMaxVal(3.0);
    ModelRecord record3 = new ModelRecord();
    record3.setFunctionForm(FunctionForm.Spline);
    record3.setFeatureFamily("spline_string");
    record3.setFeatureName("bbb");
    record3.setWeightVector(ws);
    record3.setMinVal(1.0);
    record3.setMaxVal(2.0);
    ModelRecord record4 = new ModelRecord();
    record4.setFunctionForm(FunctionForm.Linear);
    record4.setFeatureFamily("linear_float");
    record4.setFeatureName("ccc");
    record4.setWeightVector(wl);
    ModelRecord record5 = new ModelRecord();
    record5.setFunctionForm(FunctionForm.Linear);
    record5.setFeatureFamily("linear_string");
    record5.setFeatureName("ddd");
    record5.setWeightVector(wl);

    // Header first, then the four feature records (matches header.setNumRecords(4)).
    ModelRecord[] records = { record1, record2, record3, record4, record5 };
    CharArrayWriter charWriter = new CharArrayWriter();
    // try-with-resources closes (and flushes) the writer even if a write fails.
    try (BufferedWriter writer = new BufferedWriter(charWriter)) {
        for (ModelRecord record : records) {
            writer.write(Util.encode(record) + "\n");
        }
    } catch (IOException e) {
        // Fail the test but keep the underlying cause for diagnosis.
        throw new AssertionError("Could not write", e);
    }
    String serialized = charWriter.toString();
    assertTrue(serialized.length() > 0);
    FeatureVector featureVector = makeFeatureVector(2.0f, 7.0f);
    try (BufferedReader reader = new BufferedReader(new StringReader(serialized))) {
        Optional<AbstractModel> model = ModelFactory.createFromReader(reader);
        assertTrue(model.isPresent());
        float score = model.get().scoreItem(featureVector);
        assertEquals(8.0f + 10.0f + 15.0f, score, 0.001f);
    } catch (IOException e) {
        throw new AssertionError("Could not read", e);
    }
}
Also used : FeatureVector(com.airbnb.aerosolve.core.FeatureVector) ArrayList(java.util.ArrayList) IOException(java.io.IOException) CharArrayWriter(java.io.CharArrayWriter) BufferedWriter(java.io.BufferedWriter) ModelHeader(com.airbnb.aerosolve.core.ModelHeader) StringReader(java.io.StringReader) BufferedReader(java.io.BufferedReader) ModelRecord(com.airbnb.aerosolve.core.ModelRecord) Test(org.junit.Test)

Example 40 with ModelRecord

use of com.airbnb.aerosolve.core.ModelRecord in project aerosolve by airbnb.

From the class SupportVectorTest, method testSerialization.

/**
 * Round-trips an ARC_COSINE SupportVector through a ModelRecord.
 *
 * <p>The rebuilt vector is evaluated against its own support vector (1, 2), which is expected to
 * score 0.5, and against (3, 5), whose expected value is computed inline below.
 */
@Test
public void testSerialization() {
    FloatVector support = new FloatVector(new float[] { 1.0f, 2.0f });
    FloatVector other = new FloatVector(new float[] { 3.0f, 5.0f });
    SupportVector restored = new SupportVector(
        new SupportVector(support, FunctionForm.ARC_COSINE, 0.1f, 0.5f).toModelRecord());
    assertEquals(0.5, restored.evaluate(support), 0.01f);

    // Expected value: 0.5 * (1 - acos(<s, o> / (|s| * |o|)) / pi) for s = (1, 2), o = (3, 5).
    // The acos result is deliberately truncated to float before dividing by pi.
    double dot = 3.0 + 10.0;
    double normProduct = Math.sqrt((1.0 + 4.0) * (9.0 + 25.0));
    float angle = (float) Math.acos(dot / normProduct);
    double expected = 1.0 - angle / Math.PI;
    assertEquals(0.5 * expected, restored.evaluate(other), 0.01f);
}
Also used : ModelRecord(com.airbnb.aerosolve.core.ModelRecord) Test(org.junit.Test)

Aggregations

ModelRecord (com.airbnb.aerosolve.core.ModelRecord)40 ModelHeader (com.airbnb.aerosolve.core.ModelHeader)14 Test (org.junit.Test)11 ArrayList (java.util.ArrayList)10 FloatVector (com.airbnb.aerosolve.core.util.FloatVector)8 Map (java.util.Map)5 FeatureVector (com.airbnb.aerosolve.core.FeatureVector)4 BufferedReader (java.io.BufferedReader)3 BufferedWriter (java.io.BufferedWriter)3 CharArrayWriter (java.io.CharArrayWriter)3 IOException (java.io.IOException)3 StringReader (java.io.StringReader)3 LabelDictionaryEntry (com.airbnb.aerosolve.core.LabelDictionaryEntry)2 AbstractFunction (com.airbnb.aerosolve.core.function.AbstractFunction)2 Function (com.airbnb.aerosolve.core.function.Function)2 NDTreeModelTest (com.airbnb.aerosolve.core.models.NDTreeModelTest)2 SupportVector (com.airbnb.aerosolve.core.util.SupportVector)2 HashMap (java.util.HashMap)2 DebugScoreRecord (com.airbnb.aerosolve.core.DebugScoreRecord)1 MulticlassScoringResult (com.airbnb.aerosolve.core.MulticlassScoringResult)1