Use of org.deeplearning4j.ui.components.component.ComponentDiv in project deeplearning4j by deeplearning4j.
Class EvaluationTools, method getRocFromPoints.
/**
 * Renders the ROC section for a single class. The left half holds a table with the AUC and the total positive and
 * negative counts. The right half holds the ROC curve together with a diagonal reference series from (0, 0) to (1, 1).
 *
 * @param title         title of the ROC chart
 * @param points        curve coordinates: {@code points[0]} holds the false positive values (x axis) and
 *                      {@code points[1]} the y-axis values (presumably true positive rates - confirm)
 * @param positiveCount total positive count shown in the table
 * @param negativeCount total negative count shown in the table
 * @param auc           area under the curve, shown with five decimal places
 * @return a div containing the summary table and the chart side by side
 */
private static Component getRocFromPoints(String title, double[][] points, long positiveCount, long negativeCount, double auc) {
    double[] falsePositives = points[0];
    double[] yValues = points[1];
    // The same array serves as both x and y of the diagonal reference line
    double[] diagonal = { 0.0, 1.0 };

    ChartLine rocChart = new ChartLine.Builder(title, CHART_STYLE)
            .setXMin(0.0)
            .setXMax(1.0)
            .setYMin(0.0)
            .setYMax(1.0)
            .addSeries("ROC", falsePositives, yValues)
            .addSeries("", diagonal, diagonal)
            .build();

    String[][] rows = {
            { "AUC", String.format("%.5f", auc) },
            { "Total Data Positive Count", String.valueOf(positiveCount) },
            { "Total Data Negative Count", String.valueOf(negativeCount) } };
    ComponentTable summary = new ComponentTable.Builder(TABLE_STYLE)
            .header("Field", "Value")
            .content(rows)
            .build();

    ComponentDiv left = new ComponentDiv(INNER_DIV_STYLE, PAD_DIV, summary, PAD_DIV, INFO_TABLE);
    ComponentDiv right = new ComponentDiv(INNER_DIV_STYLE, rocChart);
    return new ComponentDiv(OUTER_DIV_STYLE, left, right);
}
Use of org.deeplearning4j.ui.components.component.ComponentDiv in project deeplearning4j by deeplearning4j.
Class EvaluationTools, method getPRCharts.
/**
 * Places the precision-recall curve and the precision/recall-vs-threshold chart side by side.
 *
 * @param precisionRecallTitle title of the precision-recall curve chart
 * @param prThresholdTitle     title of the precision/recall vs threshold chart
 * @param prPoints             precision-recall points used by both charts
 * @return an outer div wrapping the two charts, each inside its own inner div
 */
private static Component getPRCharts(String precisionRecallTitle, String prThresholdTitle, List<ROC.PrecisionRecallPoint> prPoints) {
    // Arguments are evaluated left to right, so the curve chart is still built before the threshold chart
    return new ComponentDiv(OUTER_DIV_STYLE,
            new ComponentDiv(INNER_DIV_STYLE, getPrecisionRecallCurve(precisionRecallTitle, prPoints)),
            new ComponentDiv(INNER_DIV_STYLE, getPrecisionRecallVsThreshold(prThresholdTitle, prPoints)));
}
Use of org.deeplearning4j.ui.components.component.ComponentDiv in project deeplearning4j by deeplearning4j.
Class EvaluationTools, method rocChartToHtml.
/**
 * Given a {@link ROCMultiClass} instance and (optionally) names for each class, render the ROC chart to a stand-alone
 * HTML file (returned as a String).
 * <p>
 * Each class is rendered one-vs-all. Its section contains:
 * <ul>
 *     <li>padding divs;</li>
 *     <li>a header "Class i (name) vs. All";</li>
 *     <li>the ROC chart with its summary table (AUC, positive/negative counts);</li>
 *     <li>the precision-recall charts.</li>
 * </ul>
 *
 * @param rocMultiClass ROC to render
 * @param classNames    Names of the classes. May be null; classes without a corresponding entry are labelled by
 *                      index only
 * @return the rendered HTML page
 */
public static String rocChartToHtml(ROCMultiClass rocMultiClass, List<String> classNames) {
    long[] actualCountPositive = rocMultiClass.getCountActualPositive();
    long[] actualCountNegative = rocMultiClass.getCountActualNegative();
    int nClasses = actualCountPositive.length;
    // Five components are added per class
    List<Component> components = new ArrayList<>(5 * nClasses);
    for (int i = 0; i < nClasses; i++) {
        double[][] points = rocMultiClass.getResultsAsArray(i);

        String headerText = "Class " + i;
        if (classNames != null && classNames.size() > i) {
            headerText += " (" + classNames.get(i) + ")";
        }
        headerText += " vs. All";

        Component headerDivPad = new ComponentDiv(HEADER_DIV_PAD_STYLE);
        components.add(headerDivPad);
        Component headerDivLeft = new ComponentDiv(HEADER_DIV_TEXT_PAD_STYLE);
        Component headerDiv = new ComponentDiv(HEADER_DIV_STYLE, new ComponentText(headerText, HEADER_TEXT_STYLE));
        Component c = getRocFromPoints(ROC_TITLE, points, actualCountPositive[i], actualCountNegative[i], rocMultiClass.calculateAUC(i));
        Component c2 = getPRCharts(PR_TITLE, PR_THRESHOLD_TITLE, rocMultiClass.getPrecisionRecallCurve(i));
        components.add(headerDivLeft);
        components.add(headerDiv);
        components.add(c);
        components.add(c2);
    }
    return StaticPageUtil.renderHTML(components);
}
Use of org.deeplearning4j.ui.components.component.ComponentDiv in project deeplearning4j by deeplearning4j.
Class StatsUtils, method exportStatsAsHTML.
/**
 * Generate and export a HTML representation (including charts, etc) of the Spark training statistics<br>
 * This overload is for writing to an output stream
 * <p>
 * The page contains:
 * <ul>
 *     <li>a header;</li>
 *     <li>the timeline chart(s) for the keys flagged by {@code defaultIncludeInPlots};</li>
 *     <li>for every key, a duration line chart plus histogram;</li>
 *     <li>for example/partition count stats, an additional count line chart plus histogram.</li>
 * </ul>
 *
 * @param sparkTrainingStats Stats to generate HTML page for
 * @param maxTimelineSizeMs  maximum amount of activity to show in a single timeline plot (multiple plots will be used if training exceeds this amount of time)
 * @param outputStream       stream the UTF-8 encoded HTML is written to. This method does not close it
 * @throws Exception IO errors or error generating HTML file
 */
public static void exportStatsAsHTML(SparkTrainingStats sparkTrainingStats, long maxTimelineSizeMs, OutputStream outputStream) throws Exception {
    Set<String> keySet = sparkTrainingStats.getKeySet();
    List<Component> components = new ArrayList<>();
    StyleChart styleChart = new StyleChart.Builder().backgroundColor(Color.WHITE).width(700, LengthUnit.Px).height(400, LengthUnit.Px).build();
    StyleText styleText = new StyleText.Builder().color(Color.BLACK).fontSize(20).build();
    Component headerText = new ComponentText("Deeplearning4j - Spark Training Analysis", styleText);
    Component header = new ComponentDiv(new StyleDiv.Builder().height(40, LengthUnit.Px).width(100, LengthUnit.Percent).build(), headerText);
    components.add(header);

    // Only keys flagged for default inclusion appear in the timeline chart(s)
    Set<String> keySetInclude = new HashSet<>();
    for (String s : keySet) {
        if (sparkTrainingStats.defaultIncludeInPlots(s)) {
            keySetInclude.add(s);
        }
    }
    Collections.addAll(components, getTrainingStatsTimelineChart(sparkTrainingStats, keySetInclude, maxTimelineSizeMs));

    for (String s : keySet) {
        List<EventStats> list = new ArrayList<>(sparkTrainingStats.getValue(s));
        Collections.sort(list, new StartTimeComparator());

        // x axis is the event index in start-time order
        double[] x = new double[list.size()];
        double[] duration = new double[list.size()];
        for (int i = 0; i < duration.length; i++) {
            x[i] = i;
            duration[i] = list.get(i).getDurationMs();
        }
        components.add(getLineAndHistogramDiv(s, "Duration", x, duration, styleChart));

        //TODO this is really ugly
        if (!list.isEmpty() && (list.get(0) instanceof ExampleCountEventStats || list.get(0) instanceof PartitionCountEventStats)) {
            boolean exCount = list.get(0) instanceof ExampleCountEventStats;
            double[] y = new double[list.size()];
            for (int i = 0; i < y.length; i++) {
                y[i] = (exCount ? ((ExampleCountEventStats) list.get(i)).getTotalExampleCount() : ((PartitionCountEventStats) list.get(i)).getNumPartitions());
            }
            String title = s + " / " + (exCount ? "Number of Examples" : "Number of Partitions");
            components.add(getLineAndHistogramDiv(title, (exCount ? "Examples" : "Partitions"), x, y, styleChart));
        }
    }
    String html = StaticPageUtil.renderHTML(components);
    outputStream.write(html.getBytes(java.nio.charset.StandardCharsets.UTF_8));
}

/**
 * Builds a full-width div holding a line chart of {@code y} against {@code x}, followed by a 20-bin histogram of
 * {@code y}. The histogram is included only if {@code y} is non-empty, not constant, and {@code getHistogram} returns
 * non-null. When all values of {@code y} are equal, the line chart's y range is widened to [v - 1, v + 1] so the
 * series remains visible.
 *
 * @param title      title used for both the line chart and the histogram
 * @param seriesName name of the line chart series
 * @param x          x values of the line chart
 * @param y          y values of the line chart, also the histogram data
 * @param styleChart chart style for both charts
 * @return the div containing the chart(s)
 */
private static Component getLineAndHistogramDiv(String title, String seriesName, double[] x, double[] y, StyleChart styleChart) {
    double min = Double.MAX_VALUE;
    double max = -Double.MAX_VALUE;
    for (double v : y) {
        min = Math.min(min, v);
        max = Math.max(max, v);
    }
    boolean constant = min == max;
    Component line = new ChartLine.Builder(title, styleChart).addSeries(seriesName, x, y).setYMin(constant ? min - 1 : null).setYMax(constant ? min + 1 : null).build();

    // For empty y, min/max stay at their sentinels (not equal), hence the explicit length check
    Component hist = null;
    if (!constant && y.length > 0) {
        hist = getHistogram(y, 20, title, styleChart);
    }
    Component[] charts;
    if (hist != null) {
        charts = new Component[] { line, hist };
    } else {
        charts = new Component[] { line };
    }
    return new ComponentDiv(new StyleDiv.Builder().width(100, LengthUnit.Percent).build(), charts);
}
Use of org.deeplearning4j.ui.components.component.ComponentDiv in project deeplearning4j by deeplearning4j.
Class StatsUtils, method getTrainingStatsTimelineChart.
/**
 * Builds the timeline chart(s) of training activity, followed by a colour key.
 * <p>
 * Each lane is one (machine ID, JVM ID, thread ID) tuple seen in the events of {@code includeSet}; lanes are ordered
 * by {@code TupleComparator}. Activity is split across several charts of {@code maxDurationMs} each. An event goes
 * into the chart whose window contains its start time, and events past the last window go into the last chart.
 * Zero-duration events are not drawn.
 *
 * @param stats         training statistics to plot
 * @param includeSet    keys of {@code stats} to include in the timeline
 * @param maxDurationMs time span covered by a single chart; values &lt;= 0 result in a single chart
 * @return the timeline charts in chronological order, followed by the colour key
 */
private static Component[] getTrainingStatsTimelineChart(SparkTrainingStats stats, Set<String> includeSet, long maxDurationMs) {
    Set<Tuple3<String, String, Long>> uniqueTuples = new HashSet<>();
    Set<String> machineIDs = new HashSet<>();
    Set<String> jvmIDs = new HashSet<>();
    Map<String, String> machineShortNames = new HashMap<>();
    Map<String, String> jvmShortNames = new HashMap<>();
    long earliestStart = Long.MAX_VALUE;
    long latestEnd = Long.MIN_VALUE;
    for (String s : includeSet) {
        List<EventStats> list = stats.getValue(s);
        for (EventStats e : list) {
            machineIDs.add(e.getMachineID());
            jvmIDs.add(e.getJvmID());
            uniqueTuples.add(new Tuple3<String, String, Long>(e.getMachineID(), e.getJvmID(), e.getThreadID()));
            earliestStart = Math.min(earliestStart, e.getStartTime());
            latestEnd = Math.max(latestEnd, e.getStartTime() + e.getDurationMs());
        }
    }
    // Short display names ("PC 0", "JVM 0", ...) used in the lane labels
    int count = 0;
    for (String s : machineIDs) {
        machineShortNames.put(s, "PC " + count++);
    }
    count = 0;
    for (String s : jvmIDs) {
        jvmShortNames.put(s, "JVM " + count++);
    }

    int nLanes = uniqueTuples.size();
    List<Tuple3<String, String, Long>> outputOrder = new ArrayList<>(uniqueTuples);
    Collections.sort(outputOrder, new TupleComparator());
    // Lane position of each tuple, so events are routed with a hash lookup instead of List.indexOf
    Map<Tuple3<String, String, Long>, Integer> laneIndex = new HashMap<>();
    String[] laneNames = new String[nLanes];
    for (int i = 0; i < nLanes; i++) {
        Tuple3<String, String, Long> t3 = outputOrder.get(i);
        laneIndex.put(t3, i);
        laneNames[i] = machineShortNames.get(t3._1()) + ", " + jvmShortNames.get(t3._2()) + ", Thread " + t3._3();
    }

    Color[] colors = getColors(includeSet.size());
    Map<String, Color> colorMap = new HashMap<>();
    count = 0;
    for (String s : includeSet) {
        colorMap.put(s, colors[count++]);
    }

    //Create key for charts:
    List<Component> tempList = new ArrayList<>();
    for (String s : includeSet) {
        String key = stats.getShortNameForKey(s) + " - " + s;
        tempList.add(new ComponentDiv(new StyleDiv.Builder().backgroundColor(colorMap.get(s)).width(33.3, LengthUnit.Percent).height(25, LengthUnit.Px).floatValue(StyleDiv.FloatValue.left).build(), new ComponentText(key, new StyleText.Builder().fontSize(11).build())));
    }
    Component key = new ComponentDiv(new StyleDiv.Builder().width(100, LengthUnit.Percent).build(), tempList);

    // How many charts? A zero maxDurationMs would otherwise throw ArithmeticException here
    int nCharts = 1;
    if (maxDurationMs > 0) {
        nCharts = Math.max(1, (int) ((latestEnd - earliestStart) / maxDurationMs));
    }

    // entriesByLane[chart][lane] -> entries drawn in that lane of that chart
    List<List<List<ChartTimeline.TimelineEntry>>> entriesByLane = new ArrayList<>();
    for (int c = 0; c < nCharts; c++) {
        List<List<ChartTimeline.TimelineEntry>> lanes = new ArrayList<>();
        for (int i = 0; i < nLanes; i++) {
            lanes.add(new ArrayList<ChartTimeline.TimelineEntry>());
        }
        entriesByLane.add(lanes);
    }

    for (String s : includeSet) {
        List<EventStats> list = stats.getValue(s);
        String shortName = stats.getShortNameForKey(s);
        Color color = colorMap.get(s);
        for (EventStats e : list) {
            if (e.getDurationMs() == 0) {
                continue;
            }
            long start = e.getStartTime();
            long end = start + e.getDurationMs();
            int chartIdx = getTimelineChartIndex(start, earliestStart, maxDurationMs, nCharts);
            Tuple3<String, String, Long> tuple = new Tuple3<>(e.getMachineID(), e.getJvmID(), e.getThreadID());
            int laneIdx = laneIndex.get(tuple);
            entriesByLane.get(chartIdx).get(laneIdx).add(new ChartTimeline.TimelineEntry(shortName, start, end, color));
        }
    }

    //Sort each lane by start time:
    Comparator<ChartTimeline.TimelineEntry> byStartTime = new Comparator<ChartTimeline.TimelineEntry>() {
        @Override
        public int compare(ChartTimeline.TimelineEntry o1, ChartTimeline.TimelineEntry o2) {
            return Long.compare(o1.getStartTimeMs(), o2.getStartTimeMs());
        }
    };
    for (List<List<ChartTimeline.TimelineEntry>> lanes : entriesByLane) {
        for (List<ChartTimeline.TimelineEntry> l : lanes) {
            Collections.sort(l, byStartTime);
        }
    }

    StyleChart sc = new StyleChart.Builder().width(1280, LengthUnit.Px).height(35 * nLanes + (60 + 20 + 25), LengthUnit.Px).margin(LengthUnit.Px, 60, 20, 200, //top, bottom, left, right
                    10).build();
    List<Component> charts = new ArrayList<>(nCharts + 1);
    for (List<List<ChartTimeline.TimelineEntry>> lanes : entriesByLane) {
        ChartTimeline.Builder b = new ChartTimeline.Builder("Timeline: Training Activities", sc);
        for (int i = 0; i < nLanes; i++) {
            b.addLane(laneNames[i], lanes.get(i));
        }
        charts.add(b.build());
    }
    charts.add(key);
    return charts.toArray(new Component[charts.size()]);
}

/**
 * Returns the index of the timeline chart whose window
 * [earliestStart + i * maxDurationMs, earliestStart + (i + 1) * maxDurationMs) contains {@code start}.
 * Start times outside every window, or a non-positive {@code maxDurationMs}, map to the last chart.
 */
private static int getTimelineChartIndex(long start, long earliestStart, long maxDurationMs, int nCharts) {
    if (maxDurationMs <= 0 || start < earliestStart) {
        return nCharts - 1;
    }
    long idx = (start - earliestStart) / maxDurationMs;
    return idx < nCharts ? (int) idx : nCharts - 1;
}
Aggregations