Search in sources :

Example 6 with FrameSplitter

Use of hex.FrameSplitter in project h2o-3 by h2oai.

The class DeepWaterAbstractIntegrationTest, method testCheckpointOverwriteWithBestModel.

/*
  public ArrayList<int[]> texts2arrayOnehot(ArrayList<String> texts) {
    int maxlen = 0;
    int index = 0;
    Map<String, Integer> dict = new HashMap<>();
    dict.put(PADDING_SYMBOL, index);
    index += 1;
    for (String text : texts) {
      String[] tokens = tokenize(text);
      for (String token : tokens) {
        if (!dict.containsKey(token)) {
          dict.put(token, index);
          index += 1;
        }
      }
      int len = tokens.length;
      if (len > maxlen) maxlen = len;
    }
    System.out.println(dict);
    System.out.println("maxlen " + maxlen);
    System.out.println("dict size " + dict.size());
    Assert.assertEquals(38, maxlen);
    Assert.assertEquals(88, index);
    Assert.assertEquals(88, dict.size());

    ArrayList<int[]> array = new ArrayList<>();
    for (String text: texts) {
      ArrayList<int[]> data = tokensToArray(tokenize(text), maxlen, dict);
      System.out.println(text);
      System.out.println(" rows " + data.size() + "  cols " + data.get(0).length);
      //for (int[] x : data) {
      //  System.out.println(Arrays.toString(x));
      //}
      array.addAll(data);
    }
    return array;
  }
  */
/**
 * Checks that continuing training from a checkpoint does not yield a worse validation
 * logloss than the checkpointed model when {@code _overwrite_with_best_model} is set.
 *
 * <p>Steps: parse the iris CSV, split it with a {@code FrameSplitter} (ratio 0.8) into
 * frames keyed "train" and "valid", train a first model with {@code _epochs = 1}, then
 * train a second model from a clone of the same parameters with {@code _epochs = 10}
 * and {@code _checkpoint} pointing at the first model's key. Every frame and model that
 * was successfully created is deleted in the {@code finally} block.
 */
@Test
public void testCheckpointOverwriteWithBestModel() {
    Frame irisFrame = null;
    Frame train = null;
    Frame valid = null;
    DeepWaterModel firstModel = null;
    DeepWaterModel resumedModel = null;
    try {
        irisFrame = parse_test_file("./smalldata/iris/iris.csv");

        // Split the parsed frame into the two frames used for training and validation.
        double[] ratios = new double[] { 0.8 };
        Key[] splitKeys = new Key[] { Key.make("train"), Key.make("valid") };
        FrameSplitter splitter = new FrameSplitter(irisFrame, ratios, splitKeys, null);
        splitter.compute2();
        train = splitter.getResult()[0];
        valid = splitter.getResult()[1];

        // Parameters for the first (single-epoch) run.
        DeepWaterParameters firstParams = new DeepWaterParameters();
        firstParams._backend = getBackend();
        firstParams._train = train._key;
        firstParams._valid = valid._key;
        firstParams._epochs = 1;
        firstParams._response_column = "C5";
        firstParams._hidden = new int[] { 50, 50 };
        firstParams._seed = 0xdecaf;
        firstParams._train_samples_per_iteration = 0;
        firstParams._score_duty_cycle = 1;
        firstParams._score_interval = 0;
        firstParams._stopping_rounds = 0;
        firstParams._overwrite_with_best_model = true;

        DeepWater firstBuilder = new DeepWater(firstParams);
        firstModel = firstBuilder.trainModel().get();
        ModelMetricsMultinomial firstMetrics =
                (ModelMetricsMultinomial) firstModel._output._validation_metrics;
        double firstLogloss = firstMetrics.logloss();

        // Second run: same settings, more epochs, starting from the first model's checkpoint.
        DeepWaterParameters resumedParams = (DeepWaterParameters) firstParams.clone();
        resumedParams._epochs = 10;
        resumedParams._checkpoint = firstModel._key;

        DeepWater resumedBuilder = new DeepWater(resumedParams);
        resumedModel = resumedBuilder.trainModel().get();
        ModelMetricsMultinomial resumedMetrics =
                (ModelMetricsMultinomial) resumedModel._output._validation_metrics;
        double resumedLogloss = resumedMetrics.logloss();

        // The resumed model must be at least as good as the checkpointed one.
        Assert.assertTrue(resumedLogloss <= firstLogloss);
    } finally {
        // Release whatever was created before the test finished or failed.
        if (irisFrame != null) {
            irisFrame.delete();
        }
        if (firstModel != null) {
            firstModel.delete();
        }
        if (resumedModel != null) {
            resumedModel.delete();
        }
        if (train != null) {
            train.delete();
        }
        if (valid != null) {
            valid.delete();
        }
    }
}
Also used : Frame(water.fvec.Frame) ShuffleSplitFrame(hex.splitframe.ShuffleSplitFrame) FrameSplitter(hex.FrameSplitter) ModelMetricsMultinomial(hex.ModelMetricsMultinomial)

Aggregations

FrameSplitter (hex.FrameSplitter) 6, Frame (water.fvec.Frame) 6, Test (org.junit.Test) 3, ModelMetricsMultinomial (hex.ModelMetricsMultinomial) 2, ShuffleSplitFrame (hex.splitframe.ShuffleSplitFrame) 2, NFSFileVec (water.fvec.NFSFileVec) 2, FrameUtils (water.util.FrameUtils) 2, DeepLearningParameters (hex.deeplearning.DeepLearningModel.DeepLearningParameters) 1, File (java.io.File) 1, TreeMap (java.util.TreeMap) 1, Key (water.Key) 1