Use of ai.djl.repository.Artifact in project djl by deepjavalibrary.
Method prepare of class CookingStackExchange.
/**
 * {@inheritDoc}
 *
 * <p>Selects the artifact file registered under {@code "train"} or {@code "test"} (depending on
 * {@code usage}), prepares the artifact, and stores the absolute path of that file in
 * {@code root}. Once a call has succeeded, later calls return immediately.
 *
 * @param progress progress tracker forwarded to {@code mrl.prepare}
 * @throws IOException if {@code usage} is neither TRAIN nor TEST, if the artifact does not list
 *     the expected file, or if preparing/resolving the artifact fails
 */
@Override
public void prepare(Progress progress) throws IOException {
    if (prepared) {
        return;
    }
    // Resolve the file key before preparing the artifact so that an unsupported usage is
    // rejected without first paying for the (potentially expensive) artifact preparation.
    String key;
    switch (usage) {
        case TRAIN:
            key = "train";
            break;
        case TEST:
            key = "test";
            break;
        case VALIDATION:
        default:
            throw new IOException("Only training and testing dataset supported.");
    }
    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);
    Artifact.Item item = artifact.getFiles().get(key);
    if (item == null) {
        // Report a clear error instead of handing a null item to getFile.
        throw new IOException("Artifact does not contain the \"" + key + "\" file.");
    }
    root = mrl.getRepository().getFile(item, "").toAbsolutePath();
    prepared = true;
}
Use of ai.djl.repository.Artifact in project djl by deepjavalibrary.
Method prepare of class StanfordMovieReview.
/**
 * {@inheritDoc}
 *
 * <p>Prepares the artifact, locates the {@code aclImdb/aclImdb/<train|test>} directory under
 * the artifact's location in the repository cache, and loads the reviews from its {@code pos}
 * and {@code neg} sub-directories. The collected texts are then passed to {@code preprocess}.
 * Once a call has succeeded, later calls return immediately.
 *
 * @param progress progress tracker forwarded to {@code mrl.prepare}
 * @throws IOException if preparing the artifact or reading review files fails
 * @throws EmbeddingException if preprocessing the review texts fails
 * @throws UnsupportedOperationException if {@code usage} is neither TRAIN nor TEST
 */
@Override
public void prepare(Progress progress) throws IOException, EmbeddingException {
    if (prepared) {
        return;
    }
    // Choose the split directory first so an unsupported usage fails before the
    // (potentially expensive) artifact preparation.
    String usageDir;
    switch (usage) {
        case TRAIN:
            usageDir = "train";
            break;
        case TEST:
            usageDir = "test";
            break;
        case VALIDATION:
        default:
            throw new UnsupportedOperationException("Validation data not available.");
    }
    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);
    Path cacheDir = mrl.getRepository().getCacheDirectory();
    URI resourceUri = artifact.getResourceUri();
    Path root = cacheDir.resolve(resourceUri.getPath()).resolve("aclImdb").resolve("aclImdb");
    Path usagePath = root.resolve(usageDir);

    List<String> reviewTexts = new ArrayList<>();
    reviewSentiments = new ArrayList<>();
    reviewImdbScore = new ArrayList<>();
    // The "pos" directory is read with flag true and the "neg" directory with flag false.
    prepareDataSentiment(usagePath.resolve("pos"), true, reviewTexts);
    prepareDataSentiment(usagePath.resolve("neg"), false, reviewTexts);
    preprocess(reviewTexts, true);
    prepared = true;
}
Use of ai.djl.repository.Artifact in project djl by deepjavalibrary.
Method prepare of class AmesRandomAccess.
/**
 * {@inheritDoc}
 *
 * <p>Prepares the artifact and points {@code csvUrl} at {@code train.csv} or {@code test.csv}
 * (depending on {@code usage}). Those files are looked up inside the
 * {@code house-prices-advanced-regression-techniques} directory of the artifact's resource
 * directory. Loading is then delegated to the superclass. Once a call has succeeded, later
 * calls return immediately.
 *
 * @param progress progress tracker forwarded to {@code mrl.prepare} and {@code super.prepare}
 * @throws IOException if preparing the artifact or loading the CSV fails
 * @throws UnsupportedOperationException if {@code usage} is neither TRAIN nor TEST
 */
@Override
public void prepare(Progress progress) throws IOException {
    if (prepared) {
        return;
    }
    // Pick the CSV name before preparing the artifact so an unsupported usage fails fast.
    String csvName;
    switch (usage) {
        case TRAIN:
            csvName = "train.csv";
            break;
        case TEST:
            csvName = "test.csv";
            break;
        case VALIDATION:
        default:
            throw new UnsupportedOperationException("Validation data not available.");
    }
    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);
    Path dir = mrl.getRepository().getResourceDirectory(artifact);
    Path root = dir.resolve("house-prices-advanced-regression-techniques");
    Path csvFile = root.resolve(csvName);
    csvUrl = csvFile.toUri().toURL();
    super.prepare(progress);
    prepared = true;
}
Use of ai.djl.repository.Artifact in project djl by deepjavalibrary.
Method prepare of class Cifar10.
/**
 * {@inheritDoc}
 *
 * <p>Prepares the artifact and reads the {@code data_batch.bin} (TRAIN) or
 * {@code test_batch.bin} (TEST) item. Column 0 of the raw array becomes the labels. The
 * remaining columns are reshaped to {@code (-1, 3, IMAGE_HEIGHT, IMAGE_WIDTH)} and transposed
 * with axes {@code (0, 2, 3, 1)} to form the images. The {@code data} and {@code labels} fields
 * are only assigned once the two arrays are known to have the same length, and later calls
 * return immediately after a successful one.
 *
 * @param progress progress tracker forwarded to {@code mrl.prepare}
 * @throws IOException if the artifact lacks the expected item, reading fails, or the number of
 *     images does not match the number of labels
 * @throws UnsupportedOperationException if {@code usage} is neither TRAIN nor TEST
 */
@Override
public void prepare(Progress progress) throws IOException {
    if (prepared) {
        return;
    }
    // Resolve the item key first so an unsupported usage fails before artifact preparation.
    String key;
    switch (usage) {
        case TRAIN:
            key = "data_batch.bin";
            break;
        case TEST:
            key = "test_batch.bin";
            break;
        case VALIDATION:
        default:
            throw new UnsupportedOperationException("Validation data not available.");
    }
    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);
    Map<String, Artifact.Item> map = artifact.getFiles();
    Artifact.Item item = map.get(key);
    if (item == null) {
        throw new IOException("Artifact does not contain the \"" + key + "\" file.");
    }
    NDArray dataAndLabels = readData(item);
    NDArray images =
            dataAndLabels
                    .get(":, 1:")
                    .reshape(-1, 3, IMAGE_HEIGHT, IMAGE_WIDTH)
                    .transpose(0, 2, 3, 1);
    NDArray imageLabels = dataAndLabels.get(":,0");
    // Validate before publishing so a mismatch never leaves inconsistent fields behind.
    if (images.size(0) != imageLabels.size(0)) {
        throw new IOException(
                "the size of data "
                        + images.size(0)
                        + " didn't match with the size of labels "
                        + imageLabels.size(0));
    }
    data = new NDArray[] {images};
    labels = new NDArray[] {imageLabels};
    prepared = true;
}
Use of ai.djl.repository.Artifact in project djl by deepjavalibrary.
Method prepare of class FashionMnist.
/**
 * {@inheritDoc}
 *
 * <p>Prepares the artifact and reads the image and label items for the configured
 * {@code usage}: {@code train_data}/{@code train_labels} for TRAIN and
 * {@code test_data}/{@code test_labels} for TEST. The element count of the label array is
 * passed to {@code readData} when reading the images. Both fields are assigned only after both
 * reads succeed, and later calls return immediately after a successful one.
 *
 * @param progress progress tracker forwarded to {@code mrl.prepare}
 * @throws IOException if the artifact lacks an expected item or reading fails
 * @throws UnsupportedOperationException if {@code usage} is neither TRAIN nor TEST
 */
@Override
public void prepare(Progress progress) throws IOException {
    if (prepared) {
        return;
    }
    // Resolve item keys first so an unsupported usage fails before artifact preparation.
    String imageKey;
    String labelKey;
    switch (usage) {
        case TRAIN:
            imageKey = "train_data";
            labelKey = "train_labels";
            break;
        case TEST:
            imageKey = "test_data";
            labelKey = "test_labels";
            break;
        case VALIDATION:
        default:
            throw new UnsupportedOperationException("Validation data not available.");
    }
    Artifact artifact = mrl.getDefaultArtifact();
    mrl.prepare(artifact, progress);
    Map<String, Artifact.Item> map = artifact.getFiles();
    Artifact.Item imageItem = map.get(imageKey);
    Artifact.Item labelItem = map.get(labelKey);
    if (imageItem == null || labelItem == null) {
        throw new IOException(
                "Artifact does not contain the \""
                        + (imageItem == null ? imageKey : labelKey)
                        + "\" file.");
    }
    // Read into locals so a failure while reading images does not leave labels published
    // without matching data.
    NDArray labelArray = readLabel(labelItem);
    NDArray imageArray = readData(imageItem, labelArray.size());
    labels = new NDArray[] {labelArray};
    data = new NDArray[] {imageArray};
    prepared = true;
}
Aggregations