Commit 091b456

jan-wassenberg authored and copybara-github committed
Minor: ParallelismStrategy->Parallelism
PiperOrigin-RevId: 828936578
1 parent a344a70 commit 091b456

File tree

13 files changed: +44 −46 lines
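Every hunk in this commit is the same mechanical rename: the enum formerly called ParallelismStrategy is now Parallelism, and call sites are updated accordingly. For orientation, here is a minimal sketch of what the renamed enum might look like after this change. It is an assumption reconstructed only from the enumerators and code comments visible in this diff; the actual definition, its header location, and the per-enumerator semantics live elsewhere in the repository and may differ.

    // Hypothetical sketch, not the repository's actual header.
    // Enumerators are those appearing at the call sites in this commit;
    // the comments paraphrase hints from the surrounding code comments.
    enum class Parallelism {
      kNone,            // run tasks sequentially (used for append-only writes)
      kFlat,            // one flat fork/join; cheapest for lightweight tasks
      kWithinCluster,   // workers restricted to the cluster given by cluster_idx
      kAcrossClusters,  // one worker per cluster; the lambda receives cluster_idx
      kHierarchical,    // across and within clusters; extra sync overhead
    };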

compression/test_util-inl.h

Lines changed: 2 additions & 2 deletions

@@ -105,7 +105,7 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents, MatPadding padding,
   MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
   MatStorageT<MatT> compressed("mat", extents, ctx.allocator, padding);
   const float scale = SfpStream::kMax / extents.Area();
-  ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
+  ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
               Callers::kTest, [&](size_t r, size_t thread) {
                 float* HWY_RESTRICT row = raw.Row(r);
                 for (size_t c = 0; c < extents.cols; c++) {
@@ -134,7 +134,7 @@ MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
   MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
   MatStorageT<MatT> compressed("trans", extents, ctx.allocator, padding);
   const float scale = SfpStream::kMax / extents.Area();
-  ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
+  ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
               Callers::kTest, [&](size_t r, size_t thread) {
                 float* HWY_RESTRICT row = raw.Row(r);
                 for (size_t c = 0; c < extents.cols; c++) {

gemma/attention.cc

Lines changed: 1 addition & 1 deletion

@@ -278,7 +278,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
   // Note that 2D parallelism is not worth the fork/join overhead because the
   // tasks are very lightweight.
   ParallelFor(
-      ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx,
+      Parallelism::kFlat, kv_heads * num_interleaved, env.ctx,
       /*cluster_idx=*/0, Callers::kAttComputeQKV,
       [&](size_t task, size_t worker) HWY_ATTR {
         const size_t head = task % kv_heads;

gemma/flash_attention.cc

Lines changed: 3 additions & 3 deletions

@@ -85,7 +85,7 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<BF16>& q_t,
   {
     const size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
     // Better than kFlat.
-    ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx,
+    ParallelFor(Parallelism::kHierarchical, num_tasks, ctx,
                 /*cluster_idx=*/0, Callers::kFlashTransposeQ, func);
   }
 }
@@ -124,7 +124,7 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
   {
     // kHierarchical is not worth the extra sync overhead because the tasks are
     // very lightweight.
-    ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx,
+    ParallelFor(Parallelism::kFlat, num_tokens * qbatch.Size(), ctx,
                 /*cluster_idx=*/0, Callers::kFlashRMSNormAndPositionalEncoding,
                 func);
   }
@@ -619,7 +619,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
   const hwy::Divisor div_qbatch(qbatch.Size());
   // Compress q to q_bf.
   ParallelFor(
-      ParallelismStrategy::kWithinCluster, activations.q.Rows(), ctx,
+      Parallelism::kWithinCluster, activations.q.Rows(), ctx,
       /*cluster_idx=*/0, Callers::kFlashAttention,
       [&](size_t row, size_t worker) {
         CompressPerThread tls;

gemma/gemma-inl.h

Lines changed: 2 additions & 2 deletions

@@ -70,7 +70,7 @@ template <class Mat>
 void ActivationBatched(
     ActivationType activation, Mat& c1, ThreadingContext& ctx,
     size_t cluster_idx = 0,
-    ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
+    Parallelism parallelism = Parallelism::kFlat) {
   using T = typename Mat::T;
   ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
               Callers::kActivationBatched, [&](uint64_t task, size_t worker) {
@@ -115,7 +115,7 @@ template <class Mat1, class Mat2>
 HWY_NOINLINE void ActivationBatched(
     ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
     size_t cluster_idx = 0,
-    ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
+    Parallelism parallelism = Parallelism::kFlat) {
   HWY_DASSERT(c1.SameShape(*c2));
   if (c2 && c2->HasPtr()) {
     ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,

gemma/gemma.cc

Lines changed: 1 addition & 1 deletion

@@ -426,7 +426,7 @@ static void SampleAndStream(const ModelConfig& config,
   timing_info.NotifyGenerated(non_eos.Count());
 
   ParallelFor(
-      ParallelismStrategy::kFlat, qbatch.Size(), env.ctx,
+      Parallelism::kFlat, qbatch.Size(), env.ctx,
       /*cluster_idx=*/0, Callers::kSampleAndStream,
       [&](size_t qi, size_t worker) {
         if (!non_eos.Get(qi)) return;

gemma/weights.cc

Lines changed: 7 additions & 8 deletions

@@ -431,12 +431,12 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
 void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
                         ThreadingContext& ctx) {
   const size_t cluster_idx = 0;
-  ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx,
+  ParallelFor(Parallelism::kFlat, c_layers.size(), ctx, cluster_idx,
               Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
                 GetLayer(layer)->Fixup(mat_owners, ctx);
               });
 
-  ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx,
+  ParallelFor(Parallelism::kFlat, vit_layers.size(), ctx, cluster_idx,
               Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
                 VitLayer(layer)->Fixup(mat_owners, ctx);
               });
@@ -527,7 +527,7 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
 
   // Allocate in parallel because faulting in large tensors is slow.
   ParallelFor(
-      ParallelismStrategy::kFlat, tensors.size(), ctx, /*cluster_idx=*/0,
+      Parallelism::kFlat, tensors.size(), ctx, /*cluster_idx=*/0,
       Callers::kAllocateAndBindAll, [&](uint64_t task, size_t /*thread*/) {
         TensorToRead& tensor = tensors[task];
         MatPtr& mat = *tensor.mat;
@@ -586,10 +586,9 @@ static void DecompressToBF16(MatPtr& mat,
 static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
                           const BlobReader& reader, ThreadingContext& ctx) {
   // Especially TSAN is slow enough to warrant hierarchical parallelism.
-  const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD
-                                           ? ParallelismStrategy::kHierarchical
-                                           : ParallelismStrategy::kFlat;
-  ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0,
+  const Parallelism parallelism =
+      HWY_IS_DEBUG_BUILD ? Parallelism::kHierarchical : Parallelism::kFlat;
+  ParallelFor(parallelism, tensors.size(), ctx, /*cluster_idx=*/0,
               Callers::kReadAllToBF16, [&](uint64_t task, size_t thread) {
                 GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadAllToBF16);
                 const TensorToRead& tensor = tensors[task];
@@ -677,7 +676,7 @@ static void ReadBatches(const BlobReader& reader,
                         const std::vector<IOBatch>& batches,
                         ThreadingContext& ctx) {
   // >5x speedup from parallel reads when cached.
-  ParallelFor(ParallelismStrategy::kHierarchical, batches.size(), ctx,
+  ParallelFor(Parallelism::kHierarchical, batches.size(), ctx,
               /*cluster_idx=*/0, Callers::kReadBatches,
               [&](uint64_t task, size_t thread) {
                 GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadBatches);

io/blob_compare.cc

Lines changed: 3 additions & 3 deletions

@@ -106,7 +106,7 @@ void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
                ThreadingContext& ctx, size_t cluster_idx) {
   HWY_ASSERT(reader.Keys().size() == blobs.size());
   HWY_ASSERT(ranges.size() == blobs.size());
-  ParallelFor(ParallelismStrategy::kWithinCluster, blobs.size(), ctx,
+  ParallelFor(Parallelism::kWithinCluster, blobs.size(), ctx,
               cluster_idx, Callers::kTest, [&](size_t i, size_t /*thread*/) {
                 HWY_ASSERT(ranges[i].bytes == blobs[i].size());
                 reader.file().Read(ranges[i].offset, ranges[i].bytes,
@@ -122,7 +122,7 @@ void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2,
   const double t0 = hwy::platform::Now();
   HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30,
            ctx.pools.NumClusters());
-  ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0, Callers::kTest,
+  ParallelFor(Parallelism::kAcrossClusters, 2, ctx, 0, Callers::kTest,
               [&](const size_t task, size_t cluster_idx) {
                 ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2,
                           task ? blobs1 : blobs2, ctx, cluster_idx);
@@ -189,7 +189,7 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
   const double t0 = hwy::platform::Now();
   std::atomic<size_t> blobs_equal{};
   std::atomic<size_t> blobs_diff{};
-  ParallelFor(ParallelismStrategy::kHierarchical, keys.size(), ctx, 0,
+  ParallelFor(Parallelism::kHierarchical, keys.size(), ctx, 0,
               Callers::kTest, [&](size_t i, size_t /*thread*/) {
                 const size_t mismatches =
                     BlobDifferences(blobs1[i], blobs2[i], keys[i]);

io/blob_store.cc

Lines changed: 3 additions & 4 deletions

@@ -488,11 +488,10 @@ void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) {
   EnqueueChunks(keys_.size() - 1, curr_offset_, bytes,
                 static_cast<const uint8_t*>(data), writes);
 
-  const ParallelismStrategy strategy = file_->IsAppendOnly()
-                                           ? ParallelismStrategy::kNone
-                                           : ParallelismStrategy::kFlat;
+  const Parallelism parallelism =
+      file_->IsAppendOnly() ? Parallelism::kNone : Parallelism::kFlat;
   ParallelFor(
-      strategy, writes.size(), ctx_,
+      parallelism, writes.size(), ctx_,
       /*cluster_idx=*/0, Callers::kBlobWriter,
       [this, &writes](uint64_t i, size_t /*thread*/) {
         const BlobRange& range = writes[i].range;

io/blob_store_test.cc

Lines changed: 1 addition & 1 deletion

@@ -130,7 +130,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
   HWY_ASSERT_EQ(reader.Keys().size(), num_blobs);
 
   ParallelFor(
-      ParallelismStrategy::kFlat, num_blobs, ctx, /*cluster_idx=*/0,
+      Parallelism::kFlat, num_blobs, ctx, /*cluster_idx=*/0,
       Callers::kTest, [&](uint64_t i, size_t /*thread*/) {
        HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(),
                             std::to_string(i).c_str());

ops/dot_test.cc

Lines changed: 1 addition & 1 deletion

@@ -1126,7 +1126,7 @@ void TestAllDot() {
   std::array<DotStats, kMaxWorkers> all_stats;
 
   ParallelFor(
-      ParallelismStrategy::kWithinCluster, kReps, ctx, 0, Callers::kTest,
+      Parallelism::kWithinCluster, kReps, ctx, 0, Callers::kTest,
       [&](size_t rep, size_t thread) {
        float* HWY_RESTRICT pa = a.Row(thread);
        float* HWY_RESTRICT pb = b.Row(thread);
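Out-of-tree callers, if any exist, would migrate the same way: only the type name at the call site changes, while ParallelFor's argument order (parallelism, number of tasks, context, cluster index, caller tag, lambda) stays as in the hunks above. A hedged before/after sketch using invented names:

    // Illustrative only: `items`, `ctx`, and Process() are made up here;
    // Callers::kTest is reused from the test call sites above.
    // Before: ParallelFor(ParallelismStrategy::kFlat, items.size(), ctx,
    //                     /*cluster_idx=*/0, Callers::kTest, ...);
    ParallelFor(Parallelism::kFlat, items.size(), ctx, /*cluster_idx=*/0,
                Callers::kTest, [&](uint64_t i, size_t /*thread*/) {
                  Process(items[i]);  // hypothetical per-item work
                });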
