Commit f59eb2e

jan-wassenberg authored and copybara-github committed

Remove multi-package support from topology

Also no longer assume equal-sized clusters

PiperOrigin-RevId: 820164125

1 parent 9b6ed1a · commit f59eb2e

14 files changed: +293 −481 lines

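Across all of the files below, the change follows one pattern: APIs that previously took a (package, cluster) pair now take a single cluster index, and direct pool.Run calls become ParallelFor with an explicit ParallelismStrategy. A minimal sketch of the new call shapes, assembled only from identifiers visible in this diff; the function name and the include comment are illustrative, not part of the commit:

#include "util/threading_context.h"  // ThreadingContext; assumed to declare ParallelFor

namespace gcpp {

// Illustrative only: post-change call shapes collected from this commit.
void SketchClusterApi(ThreadingContext& ctx) {
  // Pools and topology are now addressed by a flat cluster index; no pkg_idx.
  const size_t num_clusters = ctx.pools.NumClusters();
  hwy::ThreadPool& cluster0 = ctx.pools.Cluster(0);
  const size_t node0 = ctx.topology.GetCluster(0).Node();
  (void)num_clusters;
  (void)cluster0;
  (void)node0;

  // Work is expressed via ParallelFor plus a strategy instead of calling
  // Run on a specific pool.
  ParallelFor(ParallelismStrategy::kWithinCluster, /*num_tasks=*/8, ctx,
              /*cluster_idx=*/0,
              [&](size_t task, size_t /*thread*/) { /* per-task work */ });
}

}  // namespace gcpp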

io/blob_compare.cc

Lines changed: 31 additions & 34 deletions
@@ -28,7 +28,6 @@
 #include "util/threading_context.h"
 #include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/timer.h"
 
 namespace gcpp {
@@ -104,27 +103,31 @@ BlobVec ReserveMemory(const RangeVec& ranges, BytePtr& all_blobs, size_t& pos) {
 // Reads one set of blobs in parallel (helpful if in disk cache).
 // Aborts on error.
 void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
-               hwy::ThreadPool& pool) {
+               ThreadingContext& ctx, size_t cluster_idx) {
   HWY_ASSERT(reader.Keys().size() == blobs.size());
   HWY_ASSERT(ranges.size() == blobs.size());
-  pool.Run(0, blobs.size(), [&](size_t i, size_t /*thread*/) {
-    HWY_ASSERT(ranges[i].bytes == blobs[i].size());
-    reader.file().Read(ranges[i].offset, ranges[i].bytes, blobs[i].data());
-  });
+  ParallelFor(ParallelismStrategy::kWithinCluster, blobs.size(), ctx,
+              cluster_idx, [&](size_t i, size_t /*thread*/) {
+                HWY_ASSERT(ranges[i].bytes == blobs[i].size());
+                reader.file().Read(ranges[i].offset, ranges[i].bytes,
+                                   blobs[i].data());
+              });
 }
 
 // Parallelizes ReadBlobs across (two) packages, if available.
 void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2,
                    const RangeVec& ranges1, const RangeVec& ranges2,
                    size_t total_bytes, BlobVec& blobs1, BlobVec& blobs2,
-                   NestedPools& pools) {
+                   ThreadingContext& ctx) {
   const double t0 = hwy::platform::Now();
-  HWY_WARN("Reading %zu GiB, %zux%zu cores: ", total_bytes >> 30,
-           pools.AllPackages().NumWorkers(), pools.Pool().NumWorkers());
-  pools.AllPackages().Run(0, 2, [&](size_t task, size_t pkg_idx) {
-    ReadBlobs(task ? reader2 : reader1, task ? ranges2 : ranges1,
-              task ? blobs2 : blobs1, pools.Pool(pkg_idx));
-  });
+  HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30,
+           ctx.pools.NumClusters());
+  ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0,
+              [&](const size_t task, size_t cluster_idx) {
+                ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2,
+                          task ? blobs1 : blobs2, ctx, cluster_idx);
+              });
+
   const double t1 = hwy::platform::Now();
   HWY_WARN("%.1f GB/s\n", total_bytes / (t1 - t0) * 1E-9);
 }
@@ -181,29 +184,23 @@ size_t BlobDifferences(const ByteSpan data1, const ByteSpan data2,
 }
 
 void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
-                  size_t total_bytes, NestedPools& pools) {
+                  size_t total_bytes, ThreadingContext& ctx) {
   HWY_WARN("Comparing %zu blobs in parallel: ", keys.size());
   const double t0 = hwy::platform::Now();
   std::atomic<size_t> blobs_equal{};
   std::atomic<size_t> blobs_diff{};
-  const IndexRangePartition ranges = StaticPartition(
-      IndexRange(0, keys.size()), pools.AllPackages().NumWorkers(), 1);
-  ParallelizeOneRange(
-      ranges, pools.AllPackages(),
-      [&](const IndexRange& range, size_t pkg_idx) {
-        pools.Pool(pkg_idx).Run(
-            range.begin(), range.end(), [&](size_t i, size_t /*thread*/) {
-              const size_t mismatches =
-                  BlobDifferences(blobs1[i], blobs2[i], keys[i]);
-              if (mismatches != 0) {
-                HWY_WARN("key %s has %zu mismatches in %zu bytes!\n",
-                         keys[i].c_str(), mismatches, blobs1[i].size());
-                blobs_diff.fetch_add(1);
-              } else {
-                blobs_equal.fetch_add(1);
-              }
-            });
-      });
+  ParallelFor(ParallelismStrategy::kHierarchical, keys.size(), ctx, 0,
+              [&](size_t i, size_t /*thread*/) {
+                const size_t mismatches =
+                    BlobDifferences(blobs1[i], blobs2[i], keys[i]);
+                if (mismatches != 0) {
+                  HWY_WARN("key %s has %zu mismatches in %zu bytes!\n",
+                           keys[i].c_str(), mismatches, blobs1[i].size());
+                  blobs_diff.fetch_add(1);
+                } else {
+                  blobs_equal.fetch_add(1);
+                }
+              });
   const double t1 = hwy::platform::Now();
   HWY_WARN("%.1f GB/s; total blob matches=%zu, mismatches=%zu\n",
            total_bytes / (t1 - t0) * 1E-9, blobs_equal.load(),
@@ -230,9 +227,9 @@ void ReadAndCompareBlobs(const Path& path1, const Path& path2) {
   ThreadingArgs args;
   ThreadingContext ctx(args);
   ReadBothBlobs(reader1, reader2, ranges1, ranges2, total_bytes, blobs1, blobs2,
-                ctx.pools);
+                ctx);
 
-  CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, ctx.pools);
+  CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, ctx);
 }
 
 }  // namespace gcpp
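The new ReadBothBlobs/ReadBlobs pair shows how the two strategies compose: the outer kAcrossClusters loop hands each task a cluster_idx, and the callee forwards it to kWithinCluster so the inner workers stay on that cluster. A condensed sketch of the same shape (not part of the commit; the function name, task counts, and lambda body are placeholders):

// Assumes a ThreadingContext& ctx as constructed in ReadAndCompareBlobs above.
void SketchNestedFanOut(ThreadingContext& ctx) {
  ParallelFor(ParallelismStrategy::kAcrossClusters, /*num_tasks=*/2, ctx,
              /*cluster_idx=*/0, [&](size_t task, size_t cluster_idx) {
                // Each outer task runs on one cluster; forward its index so
                // the inner loop stays on the same cluster.
                ParallelFor(ParallelismStrategy::kWithinCluster,
                            /*num_tasks=*/64, ctx, cluster_idx,
                            [&](size_t i, size_t /*thread*/) {
                              // Per-item work for (task, i), e.g. one read.
                            });
              });
}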

ops/dot_test.cc

Lines changed: 3 additions & 2 deletions
@@ -1124,8 +1124,9 @@ void TestAllDot() {
                          MatPadding::kOdd);
   std::array<DotStats, kMaxWorkers> all_stats;
 
-  ctx.pools.Cluster(0, 0).Run(
-      0, kReps, [&](const uint32_t rep, size_t thread) {
+  ParallelFor(
+      ParallelismStrategy::kWithinCluster, kReps, ctx, 0,
+      [&](size_t rep, size_t thread) {
         float* HWY_RESTRICT pa = a.Row(thread);
         float* HWY_RESTRICT pb = b.Row(thread);
         double* HWY_RESTRICT buf = bufs.Row(thread);

ops/matmul.cc

Lines changed: 3 additions & 3 deletions
@@ -351,7 +351,7 @@ std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t M, size_t K,
 
 MatMulEnv::MatMulEnv(ThreadingContext& ctx)
     : ctx(ctx), A_BF(ctx.allocator), C_tiles(ctx) {
-  const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers();
+  const size_t num_clusters = ctx.pools.NumClusters();
   per_cluster.resize(num_clusters);
   for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
     row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(kMaxBatchSize));  // C
@@ -368,7 +368,7 @@ void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) {
 
   PROFILER_ZONE("Startup.BindB");
 
-  const size_t node = ctx.topology.GetCluster(/*pkg_idx=*/0, 0).Node();
+  const size_t node = ctx.topology.GetCluster(0).Node();
   uintptr_t begin = reinterpret_cast<uintptr_t>(B.RowBytes(0));
   uintptr_t end = begin + B.Rows() * B.Stride() * B.ElementBytes();
   // B row padding is less than the page size, so only bind the subset that
@@ -394,7 +394,7 @@ void BindC(ThreadingContext& ctx, MatPtr& C) {
   const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(),
                                       allocator.BasePageBytes());
 
-  const size_t node = ctx.topology.GetCluster(/*pkg_idx=*/0, 0).Node();
+  const size_t node = ctx.topology.GetCluster(0).Node();
   bool ok = true;
   for (size_t im = 0; im < C.Rows(); ++im) {
     ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node);

ops/matmul.h

Lines changed: 9 additions & 14 deletions
@@ -105,8 +105,7 @@ struct MMParallelWithinCluster {
               size_t inner_tasks, size_t cluster_idx, const Func& func) const {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
 
-    const size_t pkg_idx = 0;
-    hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+    hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
     const size_t base = ctx.Worker(cluster_idx);
 
     const IndexRangePartition ranges_n = StaticPartition(
@@ -122,8 +121,7 @@ struct MMParallelWithinCluster {
                   const IndexRangePartition& ranges_mc,
                   const IndexRangePartition& ranges_nc, size_t cluster_idx,
                   const Func& func) const {
-    const size_t pkg_idx = 0;
-    hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+    hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
     const size_t base = ctx.Worker(cluster_idx);
 
     // Low-batch: avoid Divide/Remainder.
@@ -143,8 +141,7 @@ struct MMParallelWithinCluster {
   template <class Func>
   void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
                   size_t cluster_idx, const Func& func) const {
-    const size_t pkg_idx = 0;
-    hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+    hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
     const size_t base = ctx.Worker(cluster_idx);
 
     cluster.Run(
@@ -164,12 +161,11 @@ struct MMParallelHierarchical {
     HWY_DASSERT(caller_cluster_idx == 0);
 
     // Single cluster: parallel-for over static partition of `range_n`.
-    const size_t pkg_idx = 0;
-    hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
+    hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
     const size_t num_clusters = all_clusters.NumWorkers();
     if (num_clusters == 1) {
       const size_t cluster_idx = 0;
-      hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+      hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
       const IndexRangePartition ranges_n = StaticPartition(
           range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
       return ParallelizeOneRange(
@@ -185,7 +181,7 @@ struct MMParallelHierarchical {
     ParallelizeOneRange(
         ranges_n, all_clusters,
         [&](const IndexRange& n_range, const size_t cluster_idx) {
-          hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+          hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
           const size_t cluster_base = ctx.Worker(cluster_idx);
           // Parallel-for over sub-ranges of `cluster_range` within the cluster.
           const IndexRangePartition worker_ranges = StaticPartition(
@@ -206,17 +202,16 @@ struct MMParallelHierarchical {
                   const IndexRangePartition& ranges_nc,
                   HWY_MAYBE_UNUSED size_t caller_cluster_idx,
                   const Func& func) const {
-    const size_t pkg_idx = 0;
     HWY_DASSERT(caller_cluster_idx == 0);
 
-    hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
+    hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
     // `all_clusters` is a pool with one worker per cluster in a package.
     const size_t num_clusters = all_clusters.NumWorkers();
     // Single (big) cluster: collapse two range indices into one parallel-for
     // to reduce the number of fork-joins.
     if (num_clusters == 1) {
       const size_t cluster_idx = 0;
-      hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+      hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
       // Low-batch: avoid Divide/Remainder.
       if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
         return ParallelizeOneRange(
@@ -237,7 +232,7 @@ struct MMParallelHierarchical {
         ranges_nc, all_clusters,
         [&](const IndexRange range_nc, size_t cluster_idx) {
          const size_t cluster_base = ctx.Worker(cluster_idx);
-          hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+          hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
           ParallelizeOneRange(ranges_mc, cluster,
                               [&](const IndexRange& range_mc, size_t worker) {
                                 func(range_mc, range_nc, cluster_base + worker);
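The hierarchical strategy above keeps a two-level split: ParallelizeOneRange over AllClusters() (one worker per cluster) assigns each cluster a sub-range, and a StaticPartition of that sub-range over the cluster's own workers does the rest. A stripped-down sketch of that structure, using only identifiers from this hunk; it is illustrative and omits the single-cluster and low-batch special cases:

// Illustrative sketch of the two-level split used by MMParallelHierarchical.
template <class Func>
void SketchHierarchicalForN(ThreadingContext& ctx, const IndexRange& range_n,
                            size_t n_multiple, const Func& func) {
  hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
  // First level: one sub-range of `range_n` per cluster.
  const IndexRangePartition ranges_n =
      StaticPartition(range_n, all_clusters.NumWorkers(), n_multiple);
  ParallelizeOneRange(
      ranges_n, all_clusters,
      [&](const IndexRange& n_range, size_t cluster_idx) {
        hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
        const size_t cluster_base = ctx.Worker(cluster_idx);
        // Second level: split this cluster's range over its own workers.
        const IndexRangePartition worker_ranges =
            StaticPartition(n_range, cluster.NumWorkers(), n_multiple);
        ParallelizeOneRange(worker_ranges, cluster,
                            [&](const IndexRange& range, size_t worker) {
                              func(range, cluster_base + worker);
                            });
      });
}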

ops/matmul_test.cc

Lines changed: 15 additions & 22 deletions
@@ -191,29 +191,22 @@ HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
   const IndexRange all_cols_c(0, C.Cols());
 
   NestedPools& pools = env.ctx.pools;
-  hwy::ThreadPool& all_packages = pools.AllPackages();
-  const IndexRangePartition get_row_c =
-      StaticPartition(all_rows_c, all_packages.NumWorkers(), 1);
+  hwy::ThreadPool& all_clusters = pools.AllClusters();
+  const size_t multiple = env.ctx.allocator.QuantumBytes() / sizeof(TB);
+  const IndexRangePartition get_col_c =
+      StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
   ParallelizeOneRange(
-      get_row_c, all_packages,
-      [&](const IndexRange& rows_c, size_t package_idx) HWY_ATTR {
-        hwy::ThreadPool& all_clusters = pools.AllClusters(package_idx);
-        const size_t multiple = env.ctx.allocator.QuantumBytes() / sizeof(TB);
-        const IndexRangePartition get_col_c =
-            StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
-        ParallelizeOneRange(
-            get_col_c, all_clusters,
-            [&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR {
-              for (size_t r : rows_c) {
-                TC* HWY_RESTRICT C_row = C.Row(r);
-                for (size_t c : cols_c) {
-                  const float add = add_row ? add_row[c] : 0.0f;
-                  const float dot =
-                      Dot(df, b_span, c * B.Stride(), A.Row(r), A.Cols());
-                  C_row[c] = hwy::ConvertScalarTo<TC>(add + scale * dot);
-                }
-              }
-            });
+      get_col_c, all_clusters,
+      [&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR {
+        for (size_t r : all_rows_c) {
+          TC* HWY_RESTRICT C_row = C.Row(r);
+          for (size_t c : cols_c) {
+            const float add = add_row ? add_row[c] : 0.0f;
+            const float dot =
+                Dot(df, b_span, c * B.Stride(), A.Row(r), A.Cols());
+            C_row[c] = hwy::ConvertScalarTo<TC>(add + scale * dot);
+          }
+        }
       });
 }
 
util/allocator.cc

Lines changed: 1 addition & 1 deletion
@@ -139,7 +139,7 @@ CacheInfo::CacheInfo(const BoundedTopology& topology) {
 
   step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_);
 
-  const BoundedTopology::Cluster& cluster = topology.GetCluster(0, 0);
+  const BoundedTopology::Cluster& cluster = topology.GetCluster(0);
   if (const hwy::Cache* caches = hwy::DataCaches()) {
     l1_bytes_ = caches[1].size_kib << 10;
     l2_bytes_ = caches[2].size_kib << 10;

util/allocator.h

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ class Allocator {
   bool ShouldBind() const { return should_bind_; }
 
   // Attempts to move(!) `[p, p + bytes)` to the given NUMA node, which is
-  // typically `BoundedTopology::GetCluster(package_idx, cluster_idx).node`.
+  // typically `BoundedTopology::GetCluster(cluster_idx).node`.
   // Writes zeros to SOME of the memory. Only call if `ShouldBind()`.
   // `p` and `bytes` must be multiples of `QuantumBytes()`.
   bool BindMemory(void* p, size_t bytes, size_t node) const;
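A small usage sketch of the contract described in that comment, assuming ctx.allocator is the Allocator declared above and using GetCluster(cluster_idx).Node() as the comment suggests; the function, its arguments, and the rounding choice are placeholders, not part of the commit:

// Illustrative only: honor the BindMemory preconditions from the comment.
void SketchBind(ThreadingContext& ctx, void* p, size_t bytes,
                size_t cluster_idx) {
  const Allocator& allocator = ctx.allocator;
  if (!allocator.ShouldBind()) return;
  // Caller ensures `p` is QuantumBytes()-aligned; round `bytes` down to a
  // multiple of QuantumBytes() as required.
  bytes = hwy::RoundDownTo(bytes, allocator.QuantumBytes());
  const size_t node = ctx.topology.GetCluster(cluster_idx).Node();
  allocator.BindMemory(p, bytes, node);  // moves(!) pages, zeros some memory
}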
