2828#include " util/threading_context.h"
2929#include " hwy/aligned_allocator.h" // Span
3030#include " hwy/base.h"
31- #include " hwy/contrib/thread_pool/thread_pool.h"
3231#include " hwy/timer.h"
3332
3433namespace gcpp {
@@ -104,27 +103,31 @@ BlobVec ReserveMemory(const RangeVec& ranges, BytePtr& all_blobs, size_t& pos) {
104103// Reads one set of blobs in parallel (helpful if in disk cache).
105104// Aborts on error.
106105void ReadBlobs (BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
107- hwy::ThreadPool& pool ) {
106+ ThreadingContext& ctx, size_t cluster_idx ) {
108107 HWY_ASSERT (reader.Keys ().size () == blobs.size ());
109108 HWY_ASSERT (ranges.size () == blobs.size ());
110- pool.Run (0 , blobs.size (), [&](size_t i, size_t /* thread*/ ) {
111- HWY_ASSERT (ranges[i].bytes == blobs[i].size ());
112- reader.file ().Read (ranges[i].offset , ranges[i].bytes , blobs[i].data ());
113- });
109+ ParallelFor (ParallelismStrategy::kWithinCluster , blobs.size (), ctx,
110+ cluster_idx, [&](size_t i, size_t /* thread*/ ) {
111+ HWY_ASSERT (ranges[i].bytes == blobs[i].size ());
112+ reader.file ().Read (ranges[i].offset , ranges[i].bytes ,
113+ blobs[i].data ());
114+ });
114115}
115116
116117// Parallelizes ReadBlobs across (two) packages, if available.
117118void ReadBothBlobs (BlobReader& reader1, BlobReader& reader2,
118119 const RangeVec& ranges1, const RangeVec& ranges2,
119120 size_t total_bytes, BlobVec& blobs1, BlobVec& blobs2,
120- NestedPools& pools ) {
121+ ThreadingContext& ctx ) {
121122 const double t0 = hwy::platform::Now ();
122- HWY_WARN (" Reading %zu GiB, %zux%zu cores: " , total_bytes >> 30 ,
123- pools.AllPackages ().NumWorkers (), pools.Pool ().NumWorkers ());
124- pools.AllPackages ().Run (0 , 2 , [&](size_t task, size_t pkg_idx) {
125- ReadBlobs (task ? reader2 : reader1, task ? ranges2 : ranges1,
126- task ? blobs2 : blobs1, pools.Pool (pkg_idx));
127- });
123+ HWY_WARN (" Reading %zu GiB, %zu clusters: " , total_bytes >> 30 ,
124+ ctx.pools .NumClusters ());
125+ ParallelFor (ParallelismStrategy::kAcrossClusters , 2 , ctx, 0 ,
126+ [&](const size_t task, size_t cluster_idx) {
127+ ReadBlobs (task ? reader1 : reader2, task ? ranges1 : ranges2,
128+ task ? blobs1 : blobs2, ctx, cluster_idx);
129+ });
130+
128131 const double t1 = hwy::platform::Now ();
129132 HWY_WARN (" %.1f GB/s\n " , total_bytes / (t1 - t0) * 1E-9 );
130133}
@@ -181,29 +184,23 @@ size_t BlobDifferences(const ByteSpan data1, const ByteSpan data2,
181184}
182185
183186void CompareBlobs (const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
184- size_t total_bytes, NestedPools& pools ) {
187+ size_t total_bytes, ThreadingContext& ctx ) {
185188 HWY_WARN (" Comparing %zu blobs in parallel: " , keys.size ());
186189 const double t0 = hwy::platform::Now ();
187190 std::atomic<size_t > blobs_equal{};
188191 std::atomic<size_t > blobs_diff{};
189- const IndexRangePartition ranges = StaticPartition (
190- IndexRange (0 , keys.size ()), pools.AllPackages ().NumWorkers (), 1 );
191- ParallelizeOneRange (
192- ranges, pools.AllPackages (),
193- [&](const IndexRange& range, size_t pkg_idx) {
194- pools.Pool (pkg_idx).Run (
195- range.begin (), range.end (), [&](size_t i, size_t /* thread*/ ) {
196- const size_t mismatches =
197- BlobDifferences (blobs1[i], blobs2[i], keys[i]);
198- if (mismatches != 0 ) {
199- HWY_WARN (" key %s has %zu mismatches in %zu bytes!\n " ,
200- keys[i].c_str (), mismatches, blobs1[i].size ());
201- blobs_diff.fetch_add (1 );
202- } else {
203- blobs_equal.fetch_add (1 );
204- }
205- });
206- });
192+ ParallelFor (ParallelismStrategy::kHierarchical , keys.size (), ctx, 0 ,
193+ [&](size_t i, size_t /* thread*/ ) {
194+ const size_t mismatches =
195+ BlobDifferences (blobs1[i], blobs2[i], keys[i]);
196+ if (mismatches != 0 ) {
197+ HWY_WARN (" key %s has %zu mismatches in %zu bytes!\n " ,
198+ keys[i].c_str (), mismatches, blobs1[i].size ());
199+ blobs_diff.fetch_add (1 );
200+ } else {
201+ blobs_equal.fetch_add (1 );
202+ }
203+ });
207204 const double t1 = hwy::platform::Now ();
208205 HWY_WARN (" %.1f GB/s; total blob matches=%zu, mismatches=%zu\n " ,
209206 total_bytes / (t1 - t0) * 1E-9 , blobs_equal.load (),
@@ -230,9 +227,9 @@ void ReadAndCompareBlobs(const Path& path1, const Path& path2) {
230227 ThreadingArgs args;
231228 ThreadingContext ctx (args);
232229 ReadBothBlobs (reader1, reader2, ranges1, ranges2, total_bytes, blobs1, blobs2,
233- ctx. pools );
230+ ctx);
234231
235- CompareBlobs (reader1.Keys (), blobs1, blobs2, total_bytes, ctx. pools );
232+ CompareBlobs (reader1.Keys (), blobs1, blobs2, total_bytes, ctx);
236233}
237234
238235} // namespace gcpp
0 commit comments