[CH] Fix some test cases too slow (#6659)
Fix slow UTs; optimize the lock in QueryContextManager

Co-authored-by: liuneng1994 <[email protected]>
liuneng1994 authored Aug 1, 2024
1 parent edae5b8 commit 61bc506
Showing 4 changed files with 69 additions and 69 deletions.
@@ -749,8 +749,7 @@ class GlutenClickHouseMergeTreePathBasedWriteSuite
}
}

- // FIXME: very slow after https://github.com/apache/incubator-gluten/pull/6558
- ignore("test mergetree path based write with bucket table") {
+ test("test mergetree path based write with bucket table") {
val dataPath = s"$basePath/lineitem_mergetree_bucket"
clearDataPath(dataPath)

@@ -760,8 +759,8 @@ class GlutenClickHouseMergeTreePathBasedWriteSuite

sourceDF.write
.format("clickhouse")
.partitionBy("l_shipdate")
.option("clickhouse.orderByKey", "l_orderkey,l_returnflag")
.partitionBy("l_returnflag")
.option("clickhouse.orderByKey", "l_orderkey")
.option("clickhouse.primaryKey", "l_orderkey")
.option("clickhouse.numBuckets", "4")
.option("clickhouse.bucketColumnNames", "l_partkey")
@@ -808,13 +807,13 @@ class GlutenClickHouseMergeTreePathBasedWriteSuite
val buckets = ClickHouseTableV2.getTable(fileIndex.deltaLog).bucketOption
assert(buckets.isDefined)
assertResult(4)(buckets.get.numBuckets)
assertResult("l_orderkey,l_returnflag")(
assertResult("l_orderkey")(
buckets.get.sortColumnNames
.mkString(","))
assertResult("l_partkey")(
buckets.get.bucketColumnNames
.mkString(","))
assertResult("l_orderkey,l_returnflag")(
assertResult("l_orderkey")(
ClickHouseTableV2
.getTable(fileIndex.deltaLog)
.orderByKeyOption
@@ -827,20 +826,21 @@ class GlutenClickHouseMergeTreePathBasedWriteSuite
.get
.mkString(","))
assertResult(1)(ClickHouseTableV2.getTable(fileIndex.deltaLog).partitionColumns.size)
assertResult("l_shipdate")(
assertResult("l_returnflag")(
ClickHouseTableV2
.getTable(fileIndex.deltaLog)
.partitionColumns
.head)
val addFiles = fileIndex.matchingFiles(Nil, Nil).map(f => f.asInstanceOf[AddMergeTreeParts])

- assertResult(10089)(addFiles.size)
+ assertResult(12)(addFiles.size)
assertResult(600572)(addFiles.map(_.rows).sum)
assertResult(4)(addFiles.count(_.partitionValues("l_shipdate").equals("1992-06-01")))
assertResult(4)(addFiles.count(_.partitionValues("l_shipdate").equals("1993-01-01")))
assertResult(4)(addFiles.count(_.partitionValues("l_shipdate").equals("1995-01-21")))
assertResult(1)(addFiles.count(
f => f.partitionValues("l_shipdate").equals("1995-01-21") && f.bucketNum.equals("00000")))
assertResult(4)(addFiles.count(_.partitionValues("l_returnflag").equals("A")))
assertResult(4)(addFiles.count(_.partitionValues("l_returnflag").equals("N")))
assertResult(4)(addFiles.count(_.partitionValues("l_returnflag").equals("R")))
assertResult(1)(
addFiles.count(
f => f.partitionValues("l_returnflag").equals("A") && f.bucketNum.equals("00000")))
}
// check part pruning effect of filter on bucket column
val df = spark.sql(s"""
@@ -855,7 +855,7 @@ class GlutenClickHouseMergeTreePathBasedWriteSuite
.flatMap(partition => partition.asInstanceOf[GlutenMergeTreePartition].partList)
.map(_.name)
.distinct
- assertResult(4)(touchedParts.size)
+ assertResult(12)(touchedParts.size)

// test upsert on partitioned & bucketed table
upsertSourceTableAndCheck(dataPath)
@@ -801,39 +801,37 @@ class GlutenClickHouseMergeTreeWriteSuite
}
}

- // FIXME: very slow after https://github.com/apache/incubator-gluten/pull/6558
- ignore("test mergetree write with bucket table") {
+ test("test mergetree write with bucket table") {
spark.sql(s"""
|DROP TABLE IF EXISTS lineitem_mergetree_bucket;
|""".stripMargin)

- spark.sql(
- s"""
- |CREATE TABLE IF NOT EXISTS lineitem_mergetree_bucket
- |(
- | l_orderkey bigint,
- | l_partkey bigint,
- | l_suppkey bigint,
- | l_linenumber bigint,
- | l_quantity double,
- | l_extendedprice double,
- | l_discount double,
- | l_tax double,
- | l_returnflag string,
- | l_linestatus string,
- | l_shipdate date,
- | l_commitdate date,
- | l_receiptdate date,
- | l_shipinstruct string,
- | l_shipmode string,
- | l_comment string
- |)
- |USING clickhouse
- |PARTITIONED BY (l_shipdate)
- |CLUSTERED BY (l_partkey)
- |${if (sparkVersion.equals("3.2")) "" else "SORTED BY (l_orderkey, l_returnflag)"} INTO 4 BUCKETS
- |LOCATION '$basePath/lineitem_mergetree_bucket'
- |""".stripMargin)
+ spark.sql(s"""
+ |CREATE TABLE IF NOT EXISTS lineitem_mergetree_bucket
+ |(
+ | l_orderkey bigint,
+ | l_partkey bigint,
+ | l_suppkey bigint,
+ | l_linenumber bigint,
+ | l_quantity double,
+ | l_extendedprice double,
+ | l_discount double,
+ | l_tax double,
+ | l_returnflag string,
+ | l_linestatus string,
+ | l_shipdate date,
+ | l_commitdate date,
+ | l_receiptdate date,
+ | l_shipinstruct string,
+ | l_shipmode string,
+ | l_comment string
+ |)
+ |USING clickhouse
+ |PARTITIONED BY (l_returnflag)
+ |CLUSTERED BY (l_partkey)
+ |${if (sparkVersion.equals("3.2")) "" else "SORTED BY (l_orderkey)"} INTO 4 BUCKETS
+ |LOCATION '$basePath/lineitem_mergetree_bucket'
+ |""".stripMargin)

spark.sql(s"""
| insert into table lineitem_mergetree_bucket
@@ -881,7 +879,7 @@ class GlutenClickHouseMergeTreeWriteSuite
if (sparkVersion.equals("3.2")) {
assert(ClickHouseTableV2.getTable(fileIndex.deltaLog).orderByKeyOption.isEmpty)
} else {
assertResult("l_orderkey,l_returnflag")(
assertResult("l_orderkey")(
ClickHouseTableV2
.getTable(fileIndex.deltaLog)
.orderByKeyOption
@@ -890,20 +888,21 @@ class GlutenClickHouseMergeTreeWriteSuite
}
assert(ClickHouseTableV2.getTable(fileIndex.deltaLog).primaryKeyOption.isEmpty)
assertResult(1)(ClickHouseTableV2.getTable(fileIndex.deltaLog).partitionColumns.size)
assertResult("l_shipdate")(
assertResult("l_returnflag")(
ClickHouseTableV2
.getTable(fileIndex.deltaLog)
.partitionColumns
.head)
val addFiles = fileIndex.matchingFiles(Nil, Nil).map(f => f.asInstanceOf[AddMergeTreeParts])

- assertResult(10089)(addFiles.size)
+ assertResult(12)(addFiles.size)
assertResult(600572)(addFiles.map(_.rows).sum)
assertResult(4)(addFiles.count(_.partitionValues("l_shipdate").equals("1992-06-01")))
assertResult(4)(addFiles.count(_.partitionValues("l_shipdate").equals("1993-01-01")))
assertResult(4)(addFiles.count(_.partitionValues("l_shipdate").equals("1995-01-21")))
assertResult(1)(addFiles.count(
f => f.partitionValues("l_shipdate").equals("1995-01-21") && f.bucketNum.equals("00000")))
assertResult(4)(addFiles.count(_.partitionValues("l_returnflag").equals("A")))
assertResult(4)(addFiles.count(_.partitionValues("l_returnflag").equals("N")))
assertResult(4)(addFiles.count(_.partitionValues("l_returnflag").equals("R")))
assertResult(1)(
addFiles.count(
f => f.partitionValues("l_returnflag").equals("A") && f.bucketNum.equals("00000")))
}
// check part pruning effect of filter on bucket column
val df = spark.sql(s"""
@@ -918,7 +917,7 @@ class GlutenClickHouseMergeTreeWriteSuite
.flatMap(partition => partition.asInstanceOf[GlutenMergeTreePartition].partList)
.map(_.name)
.distinct
- assertResult(4)(touchedParts.size)
+ assertResult(12)(touchedParts.size)

// test upsert on partitioned & bucketed table
upsertSourceTableAndCheck("lineitem_mergetree_bucket")
18 changes: 12 additions & 6 deletions cpp-ch/local-engine/Common/ConcurrentMap.h
@@ -27,13 +27,13 @@ class ConcurrentMap
public:
void insert(const K & key, const V & value)
{
- std::lock_guard lock{mutex};
+ std::unique_lock lock{mutex};
map.insert({key, value});
}

V get(const K & key)
{
- std::lock_guard lock{mutex};
+ std::shared_lock lock{mutex};
auto it = map.find(key);
if (it == map.end())
{
@@ -44,24 +44,30 @@ class ConcurrentMap

void erase(const K & key)
{
- std::lock_guard lock{mutex};
+ std::unique_lock lock{mutex};
map.erase(key);
}

void clear()
{
- std::lock_guard lock{mutex};
+ std::unique_lock lock{mutex};
map.clear();
}

+ bool contains(const K & key)
+ {
+ std::shared_lock lock{mutex};
+ return map.contains(key);
+ }

size_t size() const
{
- std::lock_guard lock{mutex};
+ std::shared_lock lock{mutex};
return map.size();
}

private:
std::unordered_map<K, V> map;
- mutable std::mutex mutex;
+ mutable std::shared_mutex mutex;
};
}
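
The point of the std::mutex to std::shared_mutex switch above is that lookups dominate this map's traffic: read paths (get, contains, size) now take shared locks and can run in parallel, while writers (insert, erase, clear) still take exclusive locks. Below is a minimal, self-contained sketch of that reader/writer pattern (a trimmed stand-in, not the actual local-engine header), with several threads calling get() concurrently:

// Sketch: trimmed ConcurrentMap stand-in (C++17); reads share the lock, writes are exclusive.
#include <cassert>
#include <mutex>
#include <shared_mutex>
#include <thread>
#include <unordered_map>
#include <vector>

template <typename K, typename V>
class ConcurrentMapSketch
{
public:
    void insert(const K & key, const V & value)
    {
        std::unique_lock lock{mutex}; // exclusive: blocks readers and other writers
        map.insert({key, value});
    }

    V get(const K & key)
    {
        std::shared_lock lock{mutex}; // shared: any number of readers may hold it at once
        auto it = map.find(key);
        return it == map.end() ? V{} : it->second;
    }

private:
    std::unordered_map<K, V> map;
    mutable std::shared_mutex mutex;
};

int main()
{
    ConcurrentMapSketch<long, long> contexts;
    for (long i = 0; i < 8; ++i)
        contexts.insert(i, i * 100);

    // Read-heavy phase: with a plain std::mutex these lookups would serialize;
    // with std::shared_mutex they overlap across threads.
    std::vector<std::thread> readers;
    for (int t = 0; t < 4; ++t)
        readers.emplace_back([&contexts] {
            for (long i = 0; i < 100000; ++i)
                assert(contexts.get(i % 8) == (i % 8) * 100);
        });
    for (auto & r : readers)
        r.join();
    return 0;
}
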
17 changes: 6 additions & 11 deletions cpp-ch/local-engine/Common/QueryContext.cpp
@@ -24,6 +24,7 @@
#include <Common/ThreadStatus.h>
#include <Common/CHUtil.h>
#include <Common/GlutenConfig.h>
+ #include <Common/ConcurrentMap.h>
#include <base/unit.h>
#include <sstream>
#include <iomanip>
@@ -48,8 +49,7 @@ struct QueryContext
ContextMutablePtr query_context;
};

- std::unordered_map<int64_t, std::shared_ptr<QueryContext>> query_map;
- std::mutex query_map_mutex;
+ ConcurrentMap<int64_t, std::shared_ptr<QueryContext>> query_map;

int64_t QueryContextManager::initializeQuery()
{
@@ -72,9 +72,8 @@ int64_t QueryContextManager::initializeQuery()

query_context->thread_group->memory_tracker.setSoftLimit(memory_limit);
query_context->thread_group->memory_tracker.setHardLimit(memory_limit + config.extra_memory_hard_limit);
- std::lock_guard<std::mutex> lock_guard(query_map_mutex);
int64_t id = reinterpret_cast<int64_t>(query_context->thread_group.get());
- query_map.emplace(id, query_context);
+ query_map.insert(id, query_context);
return id;
}

@@ -84,9 +83,8 @@ DB::ContextMutablePtr QueryContextManager::currentQueryContext()
{
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Thread group not found.");
}
- std::lock_guard lock_guard(query_map_mutex);
int64_t id = reinterpret_cast<int64_t>(CurrentThread::getGroup().get());
- return query_map[id]->query_context;
+ return query_map.get(id)->query_context;
}

void QueryContextManager::logCurrentPerformanceCounters(ProfileEvents::Counters & counters)
@@ -116,10 +114,9 @@ void QueryContextManager::logCurrentPerformanceCounters(ProfileEvents::Counters

size_t QueryContextManager::currentPeakMemory(int64_t id)
{
- std::lock_guard lock_guard(query_map_mutex);
if (!query_map.contains(id))
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "context released {}", id);
- return query_map[id]->thread_group->memory_tracker.getPeak();
+ return query_map.get(id)->thread_group->memory_tracker.getPeak();
}

void QueryContextManager::finalizeQuery(int64_t id)
@@ -130,8 +127,7 @@ void QueryContextManager::finalizeQuery(int64_t id)
}
std::shared_ptr<QueryContext> context;
{
- std::lock_guard lock_guard(query_map_mutex);
- context = query_map[id];
+ context = query_map.get(id);
}
auto query_context = context->thread_status->getQueryContext();
if (!query_context)
@@ -152,7 +148,6 @@ void QueryContextManager::finalizeQuery(int64_t id)
context->thread_status.reset();
query_context.reset();
{
- std::lock_guard lock_guard(query_map_mutex);
query_map.erase(id);
}
}
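
For context on the "optimize lock in QueryContextManager" part of the commit message: currentQueryContext() and currentPeakMemory() are read-only lookups on a map shared by every task thread, so an exclusive mutex around them makes all tasks take turns. The snippet below is a hypothetical micro-benchmark, not part of this commit (the function name, thread count, and iteration count are illustrative), comparing read-only lookups guarded by std::mutex and by std::shared_mutex; results vary by platform and critical-section length, but with many threads doing pure reads the shared lock typically scales better:

#include <chrono>
#include <iostream>
#include <mutex>
#include <shared_mutex>
#include <thread>
#include <unordered_map>
#include <vector>

// Times N threads doing read-only map lookups guarded by the given lock type.
template <typename Mutex, typename ReadLock>
long long readOnlyLookupsMs(int threads, int iters)
{
    std::unordered_map<int, int> map{{0, 0}, {1, 100}, {2, 200}};
    Mutex mutex;
    auto start = std::chrono::steady_clock::now();
    std::vector<std::thread> pool;
    for (int t = 0; t < threads; ++t)
        pool.emplace_back([&map, &mutex, iters] {
            volatile long long sink = 0; // volatile keeps the loop from being optimized out
            for (int i = 0; i < iters; ++i)
            {
                ReadLock lock{mutex}; // the only difference between the two runs below
                sink = sink + map.at(i % 3);
            }
        });
    for (auto & th : pool)
        th.join();
    auto end = std::chrono::steady_clock::now();
    return std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
}

int main()
{
    // Exclusive lock: readers take turns. Shared lock: readers proceed together.
    std::cout << "std::mutex:        "
              << readOnlyLookupsMs<std::mutex, std::lock_guard<std::mutex>>(8, 200000) << " ms\n";
    std::cout << "std::shared_mutex: "
              << readOnlyLookupsMs<std::shared_mutex, std::shared_lock<std::shared_mutex>>(8, 200000) << " ms\n";
    return 0;
}
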
