Use of com.airbnb.aerosolve.core.function.Linear in the aerosolve project by Airbnb.
Class AdditiveModelTest, method makeAdditiveModel.
/**
 * Builds the fixture model used by the tests: four feature families
 * ("spline_float", "linear_float", "spline_string", "linear_string"), each holding
 * a single function, plus offset 0.5 and slope 1.5.
 */
AdditiveModel makeAdditiveModel() {
  // Weight vectors. Each array instance is passed to two functions below,
  // so those two functions receive the very same array object.
  float[] splineWeights = { 5.0f, 10.0f, -20.0f };
  float[] linearWeights = { 1.0f, 2.0f };

  // Float features: spline "aaa" (minVal 1, maxVal 3) and linear "ccc" (minVal -10, maxVal 5).
  Map<String, Function> splineFloat = new HashMap<>();
  splineFloat.put("aaa", new Spline(1.0f, 3.0f, splineWeights));
  Map<String, Function> linearFloat = new HashMap<>();
  linearFloat.put("ccc", new Linear(-10.0f, 5.0f, linearWeights));

  // String features: for a string feature only the first element of the weights is meaningful.
  Map<String, Function> splineString = new HashMap<>();
  splineString.put("bbb", new Spline(1.0f, 2.0f, splineWeights));
  Map<String, Function> linearString = new HashMap<>();
  linearString.put("ddd", new Linear(1.0f, 1.0f, linearWeights));

  // Family name -> (feature name -> function).
  Map<String, Map<String, Function>> familyToFeatures = new HashMap<>();
  familyToFeatures.put("spline_float", splineFloat);
  familyToFeatures.put("linear_float", linearFloat);
  familyToFeatures.put("spline_string", splineString);
  familyToFeatures.put("linear_string", linearString);

  AdditiveModel fixture = new AdditiveModel();
  fixture.setWeights(familyToFeatures);
  fixture.setOffset(0.5f);
  fixture.setSlope(1.5f);
  return fixture;
}
Use of com.airbnb.aerosolve.core.function.Linear in the aerosolve project by Airbnb.
Class AdditiveModelTest, method testAddFunction.
/**
 * Checks AdditiveModel.addFunction for three cases: adding an existing feature
 * without overwrite (the original function is kept), adding an existing feature
 * with overwrite (the function is replaced), and adding a brand-new feature.
 */
@Test
public void testAddFunction() {
  AdditiveModel model = makeAdditiveModel();
  // add an existing feature without overwrite: the original "aaa" spline must be kept
  model.addFunction("spline_float", "aaa", new Spline(2.0f, 10.0f, 5), false);
  // add an existing feature with overwrite: "ccc" must be replaced
  model.addFunction("linear_float", "ccc", new Linear(3.0f, 5.0f), true);
  // add a new feature
  model.addFunction("spline_float", "new", new Spline(2.0f, 10.0f, 5), false);

  Map<String, Map<String, Function>> weights = model.getWeights();

  // Look each expected feature up directly instead of iterating over whatever is
  // present. That way a feature that was dropped or never added fails the test
  // rather than being silently skipped.
  Spline kept = (Spline) getRequiredFunction(weights, "spline_float", "aaa");
  assertTrue(kept.getMaxVal() == 3.0f);
  assertTrue(kept.getMinVal() == 1.0f);
  assertTrue(kept.getWeights().length == 3);

  Spline added = (Spline) getRequiredFunction(weights, "spline_float", "new");
  assertTrue(added.getMaxVal() == 10.0f);
  assertTrue(added.getMinVal() == 2.0f);
  assertTrue(added.getWeights().length == 5);

  Linear replaced = (Linear) getRequiredFunction(weights, "linear_float", "ccc");
  assertTrue(replaced.getWeights().length == 2);
  assertTrue(replaced.getWeights()[0] == 0.0f);
  assertTrue(replaced.getWeights()[1] == 0.0f);
  assertTrue(replaced.getMinVal() == 3.0f);
  assertTrue(replaced.getMaxVal() == 5.0f);
}

/**
 * Returns weights[family][feature], failing the current test with a descriptive
 * message if either the family or the feature is absent.
 */
private static Function getRequiredFunction(
    Map<String, Map<String, Function>> weights, String family, String feature) {
  Map<String, Function> features = weights.get(family);
  assertTrue("missing feature family: " + family, features != null);
  Function func = features.get(feature);
  assertTrue("missing feature: " + family + "/" + feature, func != null);
  return func;
}