Use of ai.djl.modality.nlp.embedding.TextEmbedding in project djl by deepjavalibrary.
Method getDataset of the class TrainSeq2Seq.
/**
 * Builds and prepares the Tatoeba English-French dataset for the requested usage.
 *
 * <p>The training usage is capped at {@code arguments.getLimit()} items; any other usage is
 * capped at one tenth of that value. Source text is processed with English lower-casing and
 * target text with French lower-casing; the target pipeline additionally appends a
 * {@code TextTerminator}.
 *
 * @param usage the dataset split to load
 * @param arguments supplies the item limit and the batch size
 * @param sourceEmbedding embedding applied to the source text; when {@code null}, an embedding
 *     size of 32 is configured instead
 * @param targetEmbedding embedding applied to the target text; when {@code null}, an embedding
 *     size of 32 is configured instead
 * @return the dataset, already prepared via {@code prepare}
 * @throws IOException declared by this method; presumably raised while preparing the dataset
 * @throws TranslateException declared by this method; presumably raised while preparing the
 *     dataset
 */
public static TextDataset getDataset(
        Dataset.Usage usage,
        Arguments arguments,
        TextEmbedding sourceEmbedding,
        TextEmbedding targetEmbedding)
        throws IOException, TranslateException {
    // Only the TRAIN usage gets the full limit; every other usage gets a tenth of it.
    long itemLimit;
    if (usage == Dataset.Usage.TRAIN) {
        itemLimit = arguments.getLimit();
    } else {
        itemLimit = arguments.getLimit() / 10;
    }

    final int padSize = 10;
    final int fallbackEmbeddingSize = 32;

    // Data batches are padded with zeros, label batches with ones.
    TatoebaEnglishFrenchDataset.Builder builder =
            TatoebaEnglishFrenchDataset.builder()
                    .setSampling(arguments.getBatchSize(), true, false)
                    .optDataBatchifier(
                            PaddingStackBatchifier.builder()
                                    .optIncludeValidLengths(true)
                                    .addPad(0, 0, manager -> manager.zeros(new Shape(1)), padSize)
                                    .build())
                    .optLabelBatchifier(
                            PaddingStackBatchifier.builder()
                                    .optIncludeValidLengths(true)
                                    .addPad(0, 0, manager -> manager.ones(new Shape(1)), padSize)
                                    .build())
                    .optUsage(usage)
                    .optPrefetchNumber(8)
                    .optLimit(itemLimit);

    // Source (English) side: tokenize, lower-case, separate punctuation, truncate to 10.
    Configuration sourceSettings =
            new Configuration()
                    .setTextProcessors(
                            Arrays.asList(
                                    new SimpleTokenizer(),
                                    new LowerCaseConvertor(Locale.ENGLISH),
                                    new PunctuationSeparator(),
                                    new TextTruncator(10)));
    if (sourceEmbedding == null) {
        sourceSettings.setEmbeddingSize(fallbackEmbeddingSize);
    } else {
        sourceSettings.setTextEmbedding(sourceEmbedding);
    }

    // Target (French) side: same pipeline, truncated to 8 and then terminated.
    Configuration targetSettings =
            new Configuration()
                    .setTextProcessors(
                            Arrays.asList(
                                    new SimpleTokenizer(),
                                    new LowerCaseConvertor(Locale.FRENCH),
                                    new PunctuationSeparator(),
                                    new TextTruncator(8),
                                    new TextTerminator()));
    if (targetEmbedding == null) {
        targetSettings.setEmbeddingSize(fallbackEmbeddingSize);
    } else {
        targetSettings.setTextEmbedding(targetEmbedding);
    }

    TatoebaEnglishFrenchDataset prepared =
            builder.setSourceConfiguration(sourceSettings)
                    .setTargetConfiguration(targetSettings)
                    .build();
    prepared.prepare(new ProgressBar());
    return prepared;
}
Aggregations