Usage of com.airbnb.aerosolve.core.ModelRecord in the aerosolve project by Airbnb.
Method modelRecord1D in class MultiDimensionSplineTest.
@Test
public void modelRecord1D() {
    // Build a populated 1-D spline, round-trip it through its ModelRecord form,
    // and check that the original and the restored copy evaluate identically.
    MultiDimensionSpline original = getMultiDimensionSpline1D();
    set1D(original);
    MultiDimensionSpline restored = new MultiDimensionSpline(original.toModelRecord("", ""));
    assertEquals(0.7, original.evaluate(3.5f), 0.0001);
    assertEquals(0.7, restored.evaluate(3.5f), 0.0001);
}
Usage of com.airbnb.aerosolve.core.ModelRecord in the aerosolve project by Airbnb.
Method toModelRecord in class Linear.
@Override
public ModelRecord toModelRecord(String featureFamily, String featureName) {
    // Serialize this linear function: its identity, its value range, and the
    // first two entries of `weights` widened to double.
    ModelRecord record = new ModelRecord();
    record.setFunctionForm(FunctionForm.Linear);
    record.setFeatureFamily(featureFamily);
    record.setFeatureName(featureName);
    record.setMinVal(minVal);
    record.setMaxVal(maxVal);
    ArrayList<Double> weightVector = new ArrayList<Double>();
    for (int idx = 0; idx < 2; idx++) {
        weightVector.add((double) weights[idx]);
    }
    record.setWeightVector(weightVector);
    return record;
}
Usage of com.airbnb.aerosolve.core.ModelRecord in the aerosolve project by Airbnb.
Method testLoad in class MlpModelTest.
@Test
public void testLoad() {
    // Serialize a small multilayer-perceptron model (header + weight records)
    // into an in-memory buffer, load it back via ModelFactory, and score a
    // feature vector with it.
    CharArrayWriter charWriter = new CharArrayWriter();
    BufferedWriter writer = new BufferedWriter(charWriter);
    // create header record
    ModelHeader header = new ModelHeader();
    header.setModelType("multilayer_perceptron");
    header.setNumHiddenLayers(1);
    ArrayList<Integer> nodeNum = new ArrayList<>(2);
    nodeNum.add(3);
    nodeNum.add(1);
    header.setNumberHiddenNodes(nodeNum);
    header.setNumRecords(2);
    ModelRecord record1 = new ModelRecord();
    record1.setModelHeader(header);
    // create records for input layer
    ModelRecord record2 = new ModelRecord();
    record2.setFeatureFamily("in");
    record2.setFeatureName("a");
    ArrayList<Double> in1 = new ArrayList<>();
    in1.add(0.0);
    in1.add(1.0);
    in1.add(1.0);
    record2.setWeightVector(in1);
    ModelRecord record3 = new ModelRecord();
    record3.setFeatureFamily("in");
    record3.setFeatureName("b");
    ArrayList<Double> in2 = new ArrayList<>();
    in2.add(1.0);
    in2.add(1.0);
    in2.add(0.0);
    record3.setWeightVector(in2);
    // create records for bias
    ModelRecord record4 = new ModelRecord();
    ArrayList<Double> b1 = new ArrayList<>();
    b1.add(0.0);
    b1.add(0.0);
    b1.add(0.0);
    record4.setWeightVector(b1);
    record4.setFunctionForm(FunctionForm.RELU);
    ModelRecord record5 = new ModelRecord();
    ArrayList<Double> b2 = new ArrayList<>();
    b2.add(0.0);
    record5.setWeightVector(b2);
    record5.setFunctionForm(FunctionForm.RELU);
    // create records for hidden layer
    ModelRecord record6 = new ModelRecord();
    ArrayList<Double> h1 = new ArrayList<>();
    h1.add(0.5);
    record6.setWeightVector(h1);
    ModelRecord record7 = new ModelRecord();
    ArrayList<Double> h2 = new ArrayList<>();
    h2.add(1.0);
    record7.setWeightVector(h2);
    ModelRecord record8 = new ModelRecord();
    ArrayList<Double> h3 = new ArrayList<>();
    h3.add(2.0);
    record8.setWeightVector(h3);
    try {
        writer.write(Util.encode(record1) + "\n");
        writer.write(Util.encode(record2) + "\n");
        writer.write(Util.encode(record3) + "\n");
        writer.write(Util.encode(record4) + "\n");
        writer.write(Util.encode(record5) + "\n");
        writer.write(Util.encode(record6) + "\n");
        writer.write(Util.encode(record7) + "\n");
        writer.write(Util.encode(record8) + "\n");
        writer.close();
    } catch (IOException e) {
        // Fail the test but keep the underlying cause in the report.
        throw new AssertionError("Could not write", e);
    }
    String serialized = charWriter.toString();
    assertTrue(serialized.length() > 0);
    StringReader strReader = new StringReader(serialized);
    BufferedReader reader = new BufferedReader(strReader);
    FeatureVector fv = makeFeatureVector();
    try {
        Optional<AbstractModel> model = ModelFactory.createFromReader(reader);
        assertTrue(model.isPresent());
        float s = model.get().scoreItem(fv);
        // assertEquals takes (expected, actual, delta).
        assertEquals(6.0f, s, 1e-10f);
    } catch (IOException e) {
        throw new AssertionError("Could not read", e);
    }
}
Usage of com.airbnb.aerosolve.core.ModelRecord in the aerosolve project by Airbnb.
Method testLoad in class AdditiveModelTest.
@Test
public void testLoad() {
    // Serialize an additive model (header + two spline and two linear feature
    // records) into an in-memory buffer, load it back via ModelFactory, and
    // check the score of a feature vector.
    CharArrayWriter charWriter = new CharArrayWriter();
    BufferedWriter writer = new BufferedWriter(charWriter);
    ModelHeader header = new ModelHeader();
    header.setModelType("additive");
    header.setNumRecords(4);
    ArrayList<Double> ws = new ArrayList<Double>();
    ws.add(5.0);
    ws.add(10.0);
    ws.add(-20.0);
    ArrayList<Double> wl = new ArrayList<Double>();
    wl.add(1.0);
    wl.add(2.0);
    ModelRecord record1 = new ModelRecord();
    record1.setModelHeader(header);
    ModelRecord record2 = new ModelRecord();
    record2.setFunctionForm(FunctionForm.Spline);
    record2.setFeatureFamily("spline_float");
    record2.setFeatureName("aaa");
    record2.setWeightVector(ws);
    record2.setMinVal(1.0);
    record2.setMaxVal(3.0);
    ModelRecord record3 = new ModelRecord();
    record3.setFunctionForm(FunctionForm.Spline);
    record3.setFeatureFamily("spline_string");
    record3.setFeatureName("bbb");
    record3.setWeightVector(ws);
    record3.setMinVal(1.0);
    record3.setMaxVal(2.0);
    ModelRecord record4 = new ModelRecord();
    record4.setFunctionForm(FunctionForm.Linear);
    record4.setFeatureFamily("linear_float");
    record4.setFeatureName("ccc");
    record4.setWeightVector(wl);
    ModelRecord record5 = new ModelRecord();
    record5.setFunctionForm(FunctionForm.Linear);
    record5.setFeatureFamily("linear_string");
    record5.setFeatureName("ddd");
    record5.setWeightVector(wl);
    try {
        writer.write(Util.encode(record1) + "\n");
        writer.write(Util.encode(record2) + "\n");
        writer.write(Util.encode(record3) + "\n");
        writer.write(Util.encode(record4) + "\n");
        writer.write(Util.encode(record5) + "\n");
        writer.close();
    } catch (IOException e) {
        // Fail the test but keep the underlying cause in the report.
        throw new AssertionError("Could not write", e);
    }
    String serialized = charWriter.toString();
    assertTrue(serialized.length() > 0);
    StringReader strReader = new StringReader(serialized);
    BufferedReader reader = new BufferedReader(strReader);
    FeatureVector featureVector = makeFeatureVector(2.0f, 7.0f);
    try {
        Optional<AbstractModel> model = ModelFactory.createFromReader(reader);
        assertTrue(model.isPresent());
        float score = model.get().scoreItem(featureVector);
        assertEquals(8.0f + 10.0f + 15.0f, score, 0.001f);
    } catch (IOException e) {
        throw new AssertionError("Could not read", e);
    }
}
Usage of com.airbnb.aerosolve.core.ModelRecord in the aerosolve project by Airbnb.
Method testSerialization in class SupportVectorTest.
@Test
public void testSerialization() {
    // Round-trip an ARC_COSINE support vector through ModelRecord and check the
    // restored vector's evaluations against hand-computed expectations.
    FloatVector v1 = new FloatVector(new float[] { 1.0f, 2.0f });
    FloatVector v2 = new FloatVector(new float[] { 3.0f, 5.0f });
    SupportVector tmp = new SupportVector(v1, FunctionForm.ARC_COSINE, 0.1f, 0.5f);
    ModelRecord rec = tmp.toModelRecord();
    SupportVector sv = new SupportVector(rec);
    assertEquals(0.5, sv.evaluate(v1), 0.01);
    // Expected value is 0.5 * (1 - theta / pi), where theta is the angle
    // between v1 and v2; computed entirely in double precision.
    double cosTheta = (3.0 + 10.0) / Math.sqrt((1.0 + 4.0) * (9.0 + 25.0));
    double expected = 1.0 - Math.acos(cosTheta) / Math.PI;
    assertEquals(0.5 * expected, sv.evaluate(v2), 0.01);
}
Aggregations