use of ai.djl.repository.Artifact in project djl by deepjavalibrary.
the class GoEmotions method prepare.
/**
 * Prepares the dataset for use with tracked progress. In this method the TSV file that belongs
 * to the configured usage will be parsed: the first column of every record is collected as
 * source text and the second column is parsed as a comma-separated list of integer labels that
 * is appended to {@code targetData}. The collected text is then preprocessed.
 *
 * @param progress the progress tracker
 * @throws IOException for various exceptions depending on the dataset
 * @throws EmbeddingException if there are exceptions during the embedding process
 * @throws UnsupportedOperationException if no data file is known for the configured usage
 */
@Override
public void prepare(Progress progress) throws IOException, EmbeddingException {
    if (prepared) {
        return;
    }

    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);

    Path root = mrl.getRepository().getResourceDirectory(artifact);
    Path csvFile;
    switch (usage) {
        case TRAIN:
            csvFile = root.resolve("train.tsv");
            break;
        case TEST:
            csvFile = root.resolve("test.tsv");
            break;
        case VALIDATION:
            csvFile = root.resolve("dev.tsv");
            break;
        default:
            throw new UnsupportedOperationException("Data not available.");
    }

    // Tab-delimited format with quoting disabled; column names come from HeaderEnum.
    CSVFormat csvFormat =
            CSVFormat.TDF.builder().setQuote(null).setHeader(HeaderEnum.class).build();
    URL csvUrl = csvFile.toUri().toURL();

    List<CSVRecord> csvRecords;
    // The parser is Closeable as well, so it is managed together with the reader instead of
    // relying on the reader's close alone.
    try (Reader reader =
                    new InputStreamReader(
                            new BufferedInputStream(csvUrl.openStream()),
                            StandardCharsets.UTF_8);
            CSVParser csvParser = new CSVParser(reader, csvFormat)) {
        csvRecords = csvParser.getRecords();
    }

    List<String> sourceTextData = new ArrayList<>(csvRecords.size());
    for (CSVRecord csvRecord : csvRecords) {
        // Column 0: the text; column 1: comma-separated integer label ids.
        sourceTextData.add(csvRecord.get(0));
        String[] labels = csvRecord.get(1).split(",");
        int[] labelInt = new int[labels.length];
        for (int i = 0; i < labels.length; i++) {
            labelInt[i] = Integer.parseInt(labels[i]);
        }
        targetData.add(labelInt);
    }

    preprocess(sourceTextData, true);
    prepared = true;
}
use of ai.djl.repository.Artifact in project djl by deepjavalibrary.
the class StanfordMovieReview method prepare.
/**
 * {@inheritDoc}
 */
@Override
public void prepare(Progress progress) throws IOException, EmbeddingException {
    if (prepared) {
        return;
    }

    Artifact defaultArtifact = mrl.getDefaultArtifact();
    mrl.prepare(defaultArtifact, progress);

    // Layout: <cache directory>/<artifact resource path>/aclImdb/aclImdb
    Path datasetRoot =
            mrl.getRepository()
                    .getCacheDirectory()
                    .resolve(defaultArtifact.getResourceUri().getPath())
                    .resolve("aclImdb")
                    .resolve("aclImdb");

    Path splitDir;
    switch (usage) {
        case TEST:
            splitDir = Paths.get("test");
            break;
        case TRAIN:
            splitDir = Paths.get("train");
            break;
        case VALIDATION:
        default:
            // Only the train and test splits are handled here.
            throw new UnsupportedOperationException("Validation data not available.");
    }
    Path usageDir = datasetRoot.resolve(splitDir);

    reviewSentiments = new ArrayList<>();
    reviewImdbScore = new ArrayList<>();
    List<String> texts = new ArrayList<>();

    // The "pos" directory is read with the sentiment flag set to true, "neg" with false.
    prepareDataSentiment(usageDir.resolve("pos"), true, texts);
    prepareDataSentiment(usageDir.resolve("neg"), false, texts);

    preprocess(texts, true);
    prepared = true;
}
use of ai.djl.repository.Artifact in project djl by deepjavalibrary.
the class StanfordQuestionAnsweringDataset method prepareUsagePath.
/**
 * Prepares the default artifact and resolves the JSON file that belongs to the configured
 * usage inside the artifact's resource directory.
 *
 * @param progress the progress tracker handed to the artifact preparation
 * @return the path of the JSON file for the configured usage
 * @throws IOException if preparing the artifact or locating its resource directory fails
 * @throws UnsupportedOperationException if the usage is neither {@code TRAIN} nor {@code TEST}
 */
private Path prepareUsagePath(Progress progress) throws IOException {
    Artifact defaultArtifact = mrl.getDefaultArtifact();
    mrl.prepare(defaultArtifact, progress);

    Path resourceDir = mrl.getRepository().getResourceDirectory(defaultArtifact);
    switch (usage) {
        case TRAIN:
            return resourceDir.resolve(Paths.get("train-v2.0.json"));
        case TEST:
            // The TEST usage is served from the "dev" file.
            return resourceDir.resolve(Paths.get("dev-v2.0.json"));
        case VALIDATION:
        default:
            throw new UnsupportedOperationException("Validation data not available.");
    }
}
use of ai.djl.repository.Artifact in project djl by deepjavalibrary.
the class UniversalDependenciesEnglishEWT method prepare.
/**
 * Prepares the dataset for use with tracked progress. In this method the TXT file will be
 * parsed. The texts will be added to {@code sourceTextData} and the Universal POS tags will be
 * added to {@code universalPosTags}. Only {@code sourceTextData} will then be preprocessed.
 *
 * <p>Each non-empty line is split on tabs: the first field is appended to the current sentence
 * (fields are joined with single spaces) and the second field is mapped to the ordinal of the
 * matching {@code UniversalPosTag}. An empty line ends the current sentence. A sentence that is
 * still open when the end of the file is reached is kept as well.
 *
 * @param progress the progress tracker
 * @throws IOException for various exceptions depending on the dataset
 * @throws EmbeddingException if there are exceptions during the embedding process
 * @throws UnsupportedOperationException if no data file is known for the configured usage
 */
@Override
public void prepare(Progress progress) throws IOException, EmbeddingException {
    if (prepared) {
        return;
    }

    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);

    Path root = mrl.getRepository().getResourceDirectory(artifact);
    Path usagePath;
    switch (usage) {
        case TRAIN:
            usagePath = Paths.get("en-ud-v2/en-ud-v2/en-ud-tag.v2.train.txt");
            break;
        case TEST:
            usagePath = Paths.get("en-ud-v2/en-ud-v2/en-ud-tag.v2.test.txt");
            break;
        case VALIDATION:
            usagePath = Paths.get("en-ud-v2/en-ud-v2/en-ud-tag.v2.dev.txt");
            break;
        default:
            // Fail with a descriptive error instead of resolving a null path below.
            throw new UnsupportedOperationException("Data not available.");
    }
    usagePath = root.resolve(usagePath);

    StringBuilder sourceTextDatum = new StringBuilder();
    List<String> sourceTextData = new ArrayList<>();
    universalPosTags = new ArrayList<>();
    List<Integer> universalPosTag = new ArrayList<>();
    try (BufferedReader reader = Files.newBufferedReader(usagePath)) {
        String row;
        while ((row = reader.readLine()) != null) {
            if (row.isEmpty()) {
                // A blank line terminates the current sentence.
                sourceTextData.add(sourceTextDatum.toString());
                universalPosTags.add(universalPosTag);
                sourceTextDatum.setLength(0);
                universalPosTag = new ArrayList<>();
                continue;
            }
            String[] splits = row.split("\t");
            if (sourceTextDatum.length() != 0) {
                sourceTextDatum.append(' ');
            }
            sourceTextDatum.append(splits[0]);
            universalPosTag.add(UniversalPosTag.valueOf(splits[1]).ordinal());
        }
    }
    // Keep the last sentence when the file does not end with a blank separator line.
    if (sourceTextDatum.length() != 0) {
        sourceTextData.add(sourceTextDatum.toString());
        universalPosTags.add(universalPosTag);
    }

    preprocess(sourceTextData, true);
    prepared = true;
}
use of ai.djl.repository.Artifact in project djl by deepjavalibrary.
the class Cifar10 method prepare.
/**
 * {@inheritDoc}
 */
@Override
public void prepare(Progress progress) throws IOException {
    if (prepared) {
        return;
    }

    Artifact defaultArtifact = mrl.getDefaultArtifact();
    mrl.prepare(defaultArtifact, progress);

    Map<String, Artifact.Item> files = defaultArtifact.getFiles();
    String fileKey;
    switch (usage) {
        case TRAIN:
            fileKey = "data_batch.bin";
            break;
        case TEST:
            fileKey = "test_batch.bin";
            break;
        case VALIDATION:
        default:
            throw new UnsupportedOperationException("Validation data not available.");
    }

    NDArray raw = readData(files.get(fileKey));

    // Column 0 becomes the labels; the remaining columns are reshaped to
    // (-1, 3, IMAGE_HEIGHT, IMAGE_WIDTH) and permuted so the axis of size 3 comes last.
    NDArray images =
            raw.get(":, 1:").reshape(-1, 3, IMAGE_HEIGHT, IMAGE_WIDTH).transpose(0, 2, 3, 1);
    NDArray classIds = raw.get(":,0");
    data = new NDArray[] {images};
    labels = new NDArray[] {classIds};

    // Both arrays must agree on the size of their first axis.
    long dataCount = images.size(0);
    long labelCount = classIds.size(0);
    if (dataCount != labelCount) {
        throw new IOException(
                "the size of data "
                        + dataCount
                        + " didn't match with the size of labels "
                        + labelCount);
    }
    prepared = true;
}
Aggregations