Use of org.apache.spark.sql.delta.util.SetAccumulator in project parent by Daytime-Don-t-Know-Dark-Night.
The class YouData, method replace.
public static void replace(Dataset<Row> ds, String uri, JdbcOptionsInWrite opts, Column part) throws SQLException, ExecutionException, SparkException, InvocationTargetException {
final double commit_factor = 0.1;
final double max_allowed_factor = 0.3;
final double max_memory_factor = 0.1;
JdbcDialect dialect = JdbcDialects.get(opts.url());
long max_allowed_packet = 4 * 1024 * 1024;
long buffer_pool_size = 128 * 1024 * 1024;
try (Connection conn = JdbcUtils.createConnectionFactory(opts).apply();
Statement statement = conn.createStatement()) {
try (ResultSet packetRes = statement.executeQuery("show global variables like 'max_allowed_packet'")) {
while (packetRes.next()) {
max_allowed_packet = packetRes.getLong("Value");
}
}
try (ResultSet bufferRes = statement.executeQuery("show global variables like 'innodb_buffer_pool_size'")) {
while (bufferRes.next()) {
buffer_pool_size = bufferRes.getLong("Value");
}
}
}
StructType schema = ds.schema();
String sql_ = JdbcUtils.getInsertStatement(opts.table(), schema, Option.empty(), true, dialect);
// Data types that are appended without single quotes when building the SQL string
List<DataType> specialType = ImmutableList.of(DataTypes.BooleanType, DataTypes.LongType, DataTypes.IntegerType);
MemoryMXBean memoryMXBean = ManagementFactory.getMemoryMXBean();
// Current heap memory usage
MemoryUsage memoryUsage = memoryMXBean.getHeapMemoryUsage();
// Maximum available heap memory
long maxMemorySize = memoryUsage.getMax();
SparkContext sparkContext = ds.sparkSession().sparkContext();
int cores = sparkContext.defaultParallelism();
int executorNum = sparkContext.getExecutorIds().size() + 1;
long partNum = Math.min(Math.round(max_allowed_packet * max_allowed_factor), Math.round(maxMemorySize * max_memory_factor * executorNum / cores));
long bufferLength = Math.round(buffer_pool_size * commit_factor);
Preconditions.checkArgument(partNum > 0, "computed partNum <= 0");
Set<Object> lastSet = Sets.newHashSet();
SetAccumulator<Object> collectionAc = new SetAccumulator<>();
collectionAc.register(SparkSession.active().sparkContext(), Option.apply("setAc"), false);
SparkSession.active().sparkContext().addSparkListener(new SparkListener() {
public synchronized void onTaskEnd(SparkListenerTaskEnd taskEnd) {
// lastSet holds the values seen in the previous callback; collectionAc.value() holds the current accumulator contents
synchronized (lastSet) {
Set<Object> diffSet;
synchronized (collectionAc.value()) {
diffSet = Sets.difference(collectionAc.value(), lastSet).immutableCopy();
}
if (!diffSet.isEmpty()) {
for (Object i : diffSet) {
System.out.println("监听器检测到差异..." + i);
}
lastSet.addAll(diffSet);
}
}
}
});
StructType structType = ds.schema();
if (Objects.nonNull(part)) {
ds = ds.withColumn("partNum", part).repartition(col("partNum")).sortWithinPartitions(col("partNum"));
} else {
ds = ds.withColumn("partNum", expr("'0'"));
}
ds.foreachPartition(rows -> {
try (Connection conn = JdbcUtils.createConnectionFactory(opts).apply();
Statement statement = conn.createStatement()) {
conn.setAutoCommit(false);
int numFields = schema.fields().length;
int executeLength = 0;
StringBuilder sb = null;
String sqlPrefix = null;
Object lastPartNum = null;
while (rows.hasNext()) {
Row row = rows.next();
// If this row targets the same table as the previous row, just append it; otherwise flush the previous batch, create the new table, then insert
Object tmpPartNum = row.getAs("partNum");
Preconditions.checkArgument(Objects.nonNull(tmpPartNum));
if (!Objects.equals(tmpPartNum, lastPartNum)) {
String newTableName = opts.table() + partSuffix(tmpPartNum);
// If the table already exists, drop it so it starts out empty
String drop_table = String.format("drop table if exists %s", newTableName);
statement.executeUpdate(drop_table);
// Create the table
String col_sql = tableCol(structType);
col_sql = col_sql.substring(0, col_sql.length() - 1);
String create_table = String.format("create table if not exists %s(%s) DEFAULT CHARSET=utf8mb4", newTableName, col_sql);
statement.executeUpdate(create_table);
sqlPrefix = sql_.substring(0, sql_.indexOf("(?"));
// Swap in the per-partition table name (the insert prefix starts with "INSERT INTO ", hence offset 12)
sqlPrefix = sqlPrefix.replace(sqlPrefix.substring(12, sqlPrefix.indexOf("(")), newTableName);
if (Objects.nonNull(lastPartNum)) {
// Flush the batch accumulated for the previous table; skip the execute if the buffer only holds the INSERT prefix
if (sb.length() > sqlPrefix.length()) {
// Track the total length of SQL executed so far
executeLength += sb.length();
statement.executeUpdate(sb.substring(0, sb.length() - 1));
}
sb.setLength(0);
sb.append(sqlPrefix);
} else {
sb = new StringBuilder((int) partNum + 1000);
sb.append(sqlPrefix);
}
collectionAc.add(tmpPartNum);
lastPartNum = tmpPartNum;
}
StringBuilder group = new StringBuilder("(");
for (int i = 0; i < numFields; i++) {
DataType type = schema.apply(i).dataType();
if (row.isNullAt(i)) {
// Handle null values
group.append("null,");
} else if (specialType.contains(type)) {
// These types are appended without surrounding quotes
Object tmp = row.getAs(i);
group.append(tmp).append(',');
} else if (type == DataTypes.StringType) {
// For string columns, escape any embedded single quotes
String tmp = row.getAs(i);
group.append("'").append(tmp.replaceAll("'", "''")).append("',");
} else {
Object tmp = row.getAs(i);
group.append("'").append(tmp).append("',");
}
}
group.delete(group.length() - 1, group.length());
group.append("),");
sb.append(group);
if (sb.length() * 2L >= partNum) {
// Track the total length of SQL executed so far
executeLength += sb.length();
statement.executeLargeUpdate(sb.delete(sb.length() - 1, sb.length()).toString());
sb.setLength(0);
sb.append(sqlPrefix);
}
// Once the accumulated length of executed SQL reaches the buffer-pool threshold, commit
if (executeLength >= bufferLength) {
logger.info("commit执行时间: {}", Instant.now());
conn.commit();
executeLength = 0;
}
}
// Flush whatever data has not been executed yet
if (Objects.nonNull(sb) && sb.length() > sqlPrefix.length()) {
String ex_sql = sb.substring(0, sb.length() - 1);
statement.executeUpdate(ex_sql);
sb.setLength(0);
sb.append(sqlPrefix);
}
{
logger.info("commit执行时间: {}", Instant.now());
conn.commit();
}
}
});
}
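The method above calls two helpers, partSuffix and tableCol, that are not included on this page. A minimal sketch of what they might look like, inferred purely from how they are used here (hypothetical; the project's real implementations may differ):
// Hypothetical helpers, reconstructed from usage above; not the project's actual code.
// Assumes org.apache.spark.sql.types.* is imported, as in the surrounding class.
static String partSuffix(Object partNum) {
// Appended to opts.table() to build a per-partition table name, e.g. "orders_2021_11_11"
return "_" + String.valueOf(partNum).replaceAll("[^A-Za-z0-9]", "_");
}
static String tableCol(StructType schema) {
// Expected to return "col TYPE,col TYPE,..." with a trailing comma,
// since the caller strips the last character before using it in CREATE TABLE.
StringBuilder sb = new StringBuilder();
for (StructField field : schema.fields()) {
DataType dt = field.dataType();
String mysqlType;
if (dt == DataTypes.LongType) {
mysqlType = "BIGINT";
} else if (dt == DataTypes.IntegerType) {
mysqlType = "INT";
} else if (dt == DataTypes.BooleanType) {
mysqlType = "TINYINT(1)";
} else if (dt == DataTypes.DoubleType) {
mysqlType = "DOUBLE";
} else if (dt == DataTypes.DateType) {
mysqlType = "DATE";
} else if (dt == DataTypes.TimestampType) {
mysqlType = "DATETIME";
} else {
mysqlType = "TEXT";
}
sb.append('`').append(field.name()).append("` ").append(mysqlType).append(',');
}
return sb.toString();
}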
Use of org.apache.spark.sql.delta.util.SetAccumulator in project parent by Daytime-Don-t-Know-Dark-Night.
The class YouData2, method replace.
public static void replace(Dataset<Row> ds, String uri, JdbcOptionsInWrite opts, int partNum) throws SQLException, ExecutionException, SparkException, InvocationTargetException {
final double commit_factor = 0.1;
final double max_allowed_factor = 0.3;
final double max_memory_factor = 0.1;
JdbcDialect dialect = JdbcDialects.get(opts.url());
long max_allowed_packet = 4 * 1024 * 1024;
long buffer_pool_size = 128 * 1024 * 1024;
try (Connection conn = JdbcUtils.createConnectionFactory(opts).apply();
Statement statement = conn.createStatement()) {
for (int i = 0; i < partNum; i++) {
String newTableName = opts.table() + partSuffix(i);
try {
// Truncate the table if it already exists
String truncate_table = String.format("truncate table %s", newTableName);
statement.executeUpdate(truncate_table);
// drop
String drop_table = String.format("drop table if exists %s", newTableName);
statement.executeUpdate(drop_table);
} catch (SQLSyntaxErrorException e) {
e.printStackTrace();
}
// Create the table
String col_sql = tableCol(ds.schema());
col_sql = col_sql.substring(0, col_sql.length() - 1);
String create_table = String.format("create table if not exists %s(%s) DEFAULT CHARSET=utf8mb4", newTableName, col_sql);
statement.executeUpdate(create_table);
}
try (ResultSet packetRes = statement.executeQuery("show global variables like 'max_allowed_packet'")) {
while (packetRes.next()) {
max_allowed_packet = packetRes.getLong("Value");
}
}
try (ResultSet bufferRes = statement.executeQuery("show global variables like 'innodb_buffer_pool_size'")) {
while (bufferRes.next()) {
buffer_pool_size = bufferRes.getLong("Value");
}
}
}
StructType schema = ds.schema();
String sql_ = JdbcUtils.getInsertStatement(opts.table(), schema, Option.empty(), true, dialect);
// Data types that are appended without single quotes when building the SQL string
List<DataType> specialType = ImmutableList.of(DataTypes.BooleanType, DataTypes.LongType, DataTypes.IntegerType);
MemoryMXBean memoryMXBean = ManagementFactory.getMemoryMXBean();
// Current heap memory usage
MemoryUsage memoryUsage = memoryMXBean.getHeapMemoryUsage();
// Maximum available heap memory
long maxMemorySize = memoryUsage.getMax();
SparkContext sparkContext = ds.sparkSession().sparkContext();
int cores = sparkContext.defaultParallelism();
int executorNum = sparkContext.getExecutorIds().size() + 1;
long partLimit = Math.min(Math.round(max_allowed_packet * max_allowed_factor), Math.round(maxMemorySize * max_memory_factor * executorNum / cores));
long bufferLength = Math.round(buffer_pool_size * commit_factor);
Preconditions.checkArgument(partLimit > 0, "computed partLimit <= 0");
// Key tuning parameters
logger.info("cores: {}, max heap: {}, executorNum: {}, partLimit: {}, bufferLength: {}", cores, maxMemorySize, executorNum, partLimit, bufferLength);
SetAccumulator<Integer> collectionAc = new SetAccumulator<>();
collectionAc.register(SparkSession.active().sparkContext(), Option.apply("setAc"), false);
Map<Integer, Integer> partitionMap = Maps.newHashMap();
SparkSession.active().sparkContext().addSparkListener(new SparkListener() {
public void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) {
int stageId = stageSubmitted.stageInfo().stageId();
int numTasks = stageSubmitted.stageInfo().numTasks();
partitionMap.put(stageId, numTasks);
}
public void onTaskEnd(SparkListenerTaskEnd taskEnd) {
int currStageId = taskEnd.stageId();
Preconditions.checkArgument(partitionMap.containsKey(currStageId), "stageId of current task: " + currStageId + ", not found in map: " + partitionMap);
int partitions = partitionMap.get(currStageId);
synchronized (collectionAc.value()) {
// Work out which table group this finished task belongs to
Set<Integer> currSet = collectionAc.value();
long index = taskEnd.taskInfo().index();
long currIndex = index % partNum;
// Number of partition ids with partitionId % partNum == currIndex
long num1 = Stream.iterate(0, k -> k + 1).limit(partitions).filter(j -> j % partNum == currIndex).count();
// Number of values already in currSet with value % partNum == currIndex
long num2 = currSet.stream().filter(j -> j % partNum == currIndex).count();
if (num1 == num2) {
// All tasks whose partition id % partNum == currIndex have finished
logger.info("All data for partNum={} has been exported", currIndex);
collectionAc.value().removeIf(k -> k % partNum == currIndex);
}
}
}
});
ds.foreachPartition(rows -> {
int taskId = TaskContext.getPartitionId();
// Partition id of this task (e.g. 0-199 when there are 200 partitions)
collectionAc.add(taskId);
long tmpPartNum = taskId % partNum;
try (Connection conn = JdbcUtils.createConnectionFactory(opts).apply();
Statement statement = conn.createStatement()) {
conn.setAutoCommit(false);
int numFields = schema.fields().length;
long executeLength = 0;
// Swap in the per-partition table name
String newTableName = opts.table() + partSuffix(tmpPartNum);
String sqlPrefix = sql_.substring(0, sql_.indexOf("(?"));
sqlPrefix = sqlPrefix.replace(sqlPrefix.substring(12, sqlPrefix.indexOf("(")), newTableName);
StringBuilder sb = new StringBuilder((int) partLimit + 1000);
sb.append(sqlPrefix);
while (rows.hasNext()) {
Row row = rows.next();
StringBuilder group = new StringBuilder("(");
for (int i = 0; i < numFields; i++) {
DataType type = schema.apply(i).dataType();
if (row.isNullAt(i)) {
// Handle null values
group.append("null,");
} else if (specialType.contains(type)) {
// These types are appended without surrounding quotes
Object tmp = row.getAs(i);
group.append(tmp).append(',');
} else if (type == DataTypes.StringType) {
// For string columns, escape any embedded single quotes
String tmp = row.getAs(i);
group.append("'").append(tmp.replaceAll("'", "''")).append("',");
} else {
Object tmp = row.getAs(i);
group.append("'").append(tmp).append("',");
}
}
group.delete(group.length() - 1, group.length());
group.append("),");
sb.append(group);
if (sb.length() * 2L >= partLimit) {
// Track the total length of SQL executed so far
executeLength += sb.length();
logger.info("任务ID为: {}, execute过的数据长度: {}", taskId, executeLength);
statement.executeLargeUpdate(sb.delete(sb.length() - 1, sb.length()).toString());
sb.setLength(0);
sb.append(sqlPrefix);
// After each execute, check the number of free buffer-pool pages; commit if fewer than 1000 remain
int bufferPageFreeNum = -1;
ResultSet bufferPageRes = statement.executeQuery("show status like 'Innodb_buffer_pool_pages_free'");
while (bufferPageRes.next()) {
bufferPageFreeNum = bufferPageRes.getInt("Value");
}
if (bufferPageFreeNum > 0 && bufferPageFreeNum < 1000) {
logger.info("缓冲池剩余空闲页为: {}, 执行commit: {}", bufferPageFreeNum, Instant.now());
conn.commit();
executeLength = 0;
}
}
}
// Flush whatever data has not been executed yet
if (sb.length() > sqlPrefix.length()) {
String ex_sql = sb.substring(0, sb.length() - 1);
logger.info("execute过的数据长度: {}", executeLength);
statement.executeUpdate(ex_sql);
sb.setLength(0);
sb.append(sqlPrefix);
}
{
logger.info("commit执行时间: {}", Instant.now());
conn.commit();
}
}
});
}
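The bookkeeping in onTaskEnd above is easiest to see with concrete numbers. The standalone sketch below uses made-up values (200 partitions fanned out over partNum = 10 tables) purely for illustration:
// Illustration only: task i writes to table suffix (i % 10), so table suffix 3 is complete
// once all task indexes congruent to 3 mod 10 (3, 13, ..., 193) have reported into the accumulator.
public class PartGroupDemo {
public static void main(String[] args) {
int partitions = 200;   // numTasks of the stage (assumed value)
int partNum = 10;       // number of per-partition tables (assumed value)
long currIndex = 3;     // the table group being checked
// num1: how many task indexes in total map to this group (3, 13, ..., 193 -> 20)
long num1 = java.util.stream.Stream.iterate(0, k -> k + 1)
.limit(partitions)
.filter(j -> j % partNum == currIndex)
.count();
// In onTaskEnd, num2 is the same count taken over the accumulator's contents;
// when num2 == num1, every task writing to this table suffix has finished.
System.out.println("tasks per table group: " + num1); // prints 20
}
}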
Use of org.apache.spark.sql.delta.util.SetAccumulator in project parent by Daytime-Don-t-Know-Dark-Night.
The class Listen, method main.
// Spark listener demo
public static void main(String[] args) {
SparkSession spark = SparkSession.builder().master("local[*]").getOrCreate();
StructType schema = new StructType().add("date", "date").add("pay_no", "string").add("amount", "double").add("pay_time", "timestamp");
Row row1 = RowFactory.create(Date.valueOf("2021-11-11"), "001", 1.0, Timestamp.from(Instant.now()));
Row row2 = RowFactory.create(Date.valueOf("2021-11-11"), "002", 2.0, Timestamp.from(Instant.now()));
Row row3 = RowFactory.create(Date.valueOf("2021-11-12"), "003", 3.0, Timestamp.from(Instant.now()));
Row row4 = RowFactory.create(Date.valueOf("2021-11-12"), "004", 4.0, Timestamp.from(Instant.now()));
Row row5 = RowFactory.create(Date.valueOf("2021-11-13"), "005", 5.0, Timestamp.from(Instant.now()));
Dataset<Row> ds = spark.createDataFrame(ImmutableList.of(row1, row2, row3, row4, row5), schema);
Set<Object> lastSet = Sets.newHashSet();
SetAccumulator<Object> collectionAc = new SetAccumulator<>();
collectionAc.register(SparkSession.active().sparkContext(), Option.apply("setAc"), false);
SparkSession.active().sparkContext().addSparkListener(new SparkListener() {
public void onTaskEnd(SparkListenerTaskEnd taskEnd) {
// lastSet holds the values reported previously; collectionAc.value() holds the current accumulator contents
Set<Object> diffSet = Sets.difference(collectionAc.value(), lastSet);
if (!diffSet.isEmpty()) {
for (Object obj : diffSet) {
System.out.println("Listener detected a new value... " + obj);
}
// Remember what has been reported so the same values are not printed again on the next task end
lastSet.addAll(diffSet);
}
}
});
ds.repartition(col("date")).sortWithinPartitions(col("date")).foreachPartition(rows -> {
Object lastPartNum = null;
while (rows.hasNext()) {
Row row = rows.next();
Object tmpPartNum = row.getAs("date");
if (!Objects.equals(tmpPartNum, lastPartNum)) {
collectionAc.add(tmpPartNum);
lastPartNum = tmpPartNum;
}
}
});
ds.show(false);
}
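For reference, org.apache.spark.sql.delta.util.SetAccumulator is a small Delta Lake utility written in Scala. A rough Java equivalent of the behavior these examples rely on (register on the SparkContext, add values from executors, read the merged set on the driver) might look like the sketch below; this is an approximation for illustration, not the actual Delta source:
// Approximate Java counterpart of Delta's SetAccumulator (the real class is Scala).
// It accumulates distinct values from tasks into a java.util.Set readable on the driver.
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import org.apache.spark.util.AccumulatorV2;
public class JavaSetAccumulator<T> extends AccumulatorV2<T, Set<T>> {
private final Set<T> set = Collections.synchronizedSet(new HashSet<>());
@Override
public boolean isZero() { return set.isEmpty(); }
@Override
public AccumulatorV2<T, Set<T>> copy() {
JavaSetAccumulator<T> copy = new JavaSetAccumulator<>();
copy.set.addAll(set);
return copy;
}
@Override
public void reset() { set.clear(); }
@Override
public void add(T v) { set.add(v); }
@Override
public void merge(AccumulatorV2<T, Set<T>> other) { set.addAll(other.value()); }
@Override
public Set<T> value() { return set; }
}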