Search in sources:

Example 1 with Linear

Use of com.airbnb.aerosolve.core.function.Linear in the aerosolve project by Airbnb.

Source: class AdditiveModelTest, method makeAdditiveModel.

/**
 * Builds a small {@link AdditiveModel} fixture for the tests.
 *
 * <p>The model has four feature families, each holding one feature:
 * "spline_float" -> "aaa", "linear_float" -> "ccc", "spline_string" -> "bbb"
 * and "linear_string" -> "ddd". The offset is 0.5 and the slope is 1.5.
 *
 * @return a freshly built model
 */
AdditiveModel makeAdditiveModel() {
    // One weight array is handed to both splines and another to both linears,
    // exactly as in the original fixture.
    float[] splineWeights = { 5.0f, 10.0f, -20.0f };
    float[] linearWeights = { 1.0f, 2.0f };

    Map<String, Function> splineFloat = new HashMap<>();
    splineFloat.put("aaa", new Spline(1.0f, 3.0f, splineWeights));

    Map<String, Function> splineString = new HashMap<>();
    // For a string feature, only the first element of the weights is meaningful.
    splineString.put("bbb", new Spline(1.0f, 2.0f, splineWeights));

    Map<String, Function> linearFloat = new HashMap<>();
    linearFloat.put("ccc", new Linear(-10.0f, 5.0f, linearWeights));

    Map<String, Function> linearString = new HashMap<>();
    linearString.put("ddd", new Linear(1.0f, 1.0f, linearWeights));

    Map<String, Map<String, Function>> families = new HashMap<>();
    families.put("spline_float", splineFloat);
    families.put("linear_float", linearFloat);
    families.put("spline_string", splineString);
    families.put("linear_string", linearString);

    AdditiveModel model = new AdditiveModel();
    model.setWeights(families);
    model.setOffset(0.5f);
    model.setSlope(1.5f);
    return model;
}
Also used: Function (com.airbnb.aerosolve.core.function.Function), HashMap (java.util.HashMap), Map (java.util.Map), Spline (com.airbnb.aerosolve.core.function.Spline), Linear (com.airbnb.aerosolve.core.function.Linear)

Example 2 with Linear

Use of com.airbnb.aerosolve.core.function.Linear in the aerosolve project by Airbnb.

Source: class AdditiveModelTest, method testAddFunction.

/**
 * Verifies {@code AdditiveModel.addFunction} for three cases: an existing
 * feature without overwrite, an existing feature with overwrite, and a new
 * feature added to an existing family.
 *
 * <p>Each expected family and feature is looked up directly and asserted to be
 * present. Iterating over whatever entries happen to exist would let the test
 * pass silently if a feature were dropped or never added.
 */
@Test
public void testAddFunction() {
    AdditiveModel model = makeAdditiveModel();
    // add an existing feature without overwrite: the original spline must be kept
    model.addFunction("spline_float", "aaa", new Spline(2.0f, 10.0f, 5), false);
    // add an existing feature with overwrite: the linear must be replaced
    model.addFunction("linear_float", "ccc", new Linear(3.0f, 5.0f), true);
    // add a new feature to an existing family
    model.addFunction("spline_float", "new", new Spline(2.0f, 10.0f, 5), false);

    Map<String, Map<String, Function>> weights = model.getWeights();
    Map<String, Function> splineFloat = weights.get("spline_float");
    Map<String, Function> linearFloat = weights.get("linear_float");
    assertTrue(splineFloat != null);
    assertTrue(linearFloat != null);

    // "aaa" still carries the fixture's original spline (min 1, max 3, 3 weights).
    Spline kept = (Spline) splineFloat.get("aaa");
    assertTrue(kept != null);
    assertTrue(kept.getMaxVal() == 3.0f);
    assertTrue(kept.getMinVal() == 1.0f);
    assertTrue(kept.getWeights().length == 3);

    // "new" was added with the requested bounds and 5 weights.
    Spline added = (Spline) splineFloat.get("new");
    assertTrue(added != null);
    assertTrue(added.getMaxVal() == 10.0f);
    assertTrue(added.getMinVal() == 2.0f);
    assertTrue(added.getWeights().length == 5);

    // "ccc" was overwritten: new bounds and two zero weights, not the fixture's {1, 2}.
    Linear replaced = (Linear) linearFloat.get("ccc");
    assertTrue(replaced != null);
    assertTrue(replaced.getWeights().length == 2);
    assertTrue(replaced.getWeights()[0] == 0.0f);
    assertTrue(replaced.getWeights()[1] == 0.0f);
    assertTrue(replaced.getMinVal() == 3.0f);
    assertTrue(replaced.getMaxVal() == 5.0f);
}
Also used: Function (com.airbnb.aerosolve.core.function.Function), HashMap (java.util.HashMap), Map (java.util.Map), Spline (com.airbnb.aerosolve.core.function.Spline), Linear (com.airbnb.aerosolve.core.function.Linear), Test (org.junit.Test)

Aggregations

Function (com.airbnb.aerosolve.core.function.Function): 2, Linear (com.airbnb.aerosolve.core.function.Linear): 2, Spline (com.airbnb.aerosolve.core.function.Spline): 2, HashMap (java.util.HashMap): 2, Map (java.util.Map): 2, Test (org.junit.Test): 1