Skip to content

Commit c9dae43

Browse files
[memprof] Add access checks to PortableMemInfoBlock::get* (#90121)
commit 4c8ec8f Author: Kazu Hirata <[email protected]> Date: Wed Apr 24 16:25:35 2024 -0700 introduced the idea of serializing/deserializing a subset of the fields in PortableMemInfoBlock. While it reduces the size of the indexed MemProf profile file, we could now inadvertently access unavailable fields without noticing. To protect ourselves from that risk, this patch adds access checks to the PortableMemInfoBlock::get* methods by embedding a bit set representing the available fields into PortableMemInfoBlock.
1 parent 3526020 commit c9dae43

File tree

3 files changed

+97
-12
lines changed

3 files changed

+97
-12
lines changed

llvm/include/llvm/ProfileData/MemProf.h

Lines changed: 32 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
#define LLVM_PROFILEDATA_MEMPROF_H_
33

44
#include "llvm/ADT/MapVector.h"
5+
#include "llvm/ADT/STLForwardCompat.h"
56
#include "llvm/ADT/STLFunctionalExtras.h"
67
#include "llvm/ADT/SmallVector.h"
78
#include "llvm/IR/GlobalValue.h"
@@ -10,6 +11,7 @@
1011
#include "llvm/Support/EndianStream.h"
1112
#include "llvm/Support/raw_ostream.h"
1213

14+
#include <bitset>
1315
#include <cstdint>
1416
#include <optional>
1517

@@ -55,7 +57,10 @@ MemProfSchema getHotColdSchema();
5557
// deserialize methods.
5658
struct PortableMemInfoBlock {
5759
PortableMemInfoBlock() = default;
58-
explicit PortableMemInfoBlock(const MemInfoBlock &Block) {
60+
explicit PortableMemInfoBlock(const MemInfoBlock &Block,
61+
const MemProfSchema &IncomingSchema) {
62+
for (const Meta Id : IncomingSchema)
63+
Schema.set(llvm::to_underlying(Id));
5964
#define MIBEntryDef(NameTag, Name, Type) Name = Block.Name;
6065
#include "llvm/ProfileData/MIBEntryDef.inc"
6166
#undef MIBEntryDef
@@ -67,10 +72,12 @@ struct PortableMemInfoBlock {
6772

6873
// Read the contents of \p Ptr based on the \p Schema to populate the
6974
// MemInfoBlock member.
70-
void deserialize(const MemProfSchema &Schema, const unsigned char *Ptr) {
75+
void deserialize(const MemProfSchema &IncomingSchema,
76+
const unsigned char *Ptr) {
7177
using namespace support;
7278

73-
for (const Meta Id : Schema) {
79+
Schema.reset();
80+
for (const Meta Id : IncomingSchema) {
7481
switch (Id) {
7582
#define MIBEntryDef(NameTag, Name, Type) \
7683
case Meta::Name: { \
@@ -82,6 +89,8 @@ struct PortableMemInfoBlock {
8289
llvm_unreachable("Unknown meta type id, is the profile collected from "
8390
"a newer version of the runtime?");
8491
}
92+
93+
Schema.set(llvm::to_underlying(Id));
8594
}
8695
}
8796

@@ -114,17 +123,29 @@ struct PortableMemInfoBlock {
114123
#undef MIBEntryDef
115124
}
116125

126+
// Return the schema, only for unit tests.
127+
std::bitset<llvm::to_underlying(Meta::Size)> getSchema() const {
128+
return Schema;
129+
}
130+
117131
// Define getters for each type which can be called by analyses.
118132
#define MIBEntryDef(NameTag, Name, Type) \
119-
Type get##Name() const { return Name; }
133+
Type get##Name() const { \
134+
assert(Schema[llvm::to_underlying(Meta::Name)]); \
135+
return Name; \
136+
}
120137
#include "llvm/ProfileData/MIBEntryDef.inc"
121138
#undef MIBEntryDef
122139

123140
void clear() { *this = PortableMemInfoBlock(); }
124141

125142
bool operator==(const PortableMemInfoBlock &Other) const {
143+
if (Other.Schema != Schema)
144+
return false;
145+
126146
#define MIBEntryDef(NameTag, Name, Type) \
127-
if (Other.get##Name() != get##Name()) \
147+
if (Schema[llvm::to_underlying(Meta::Name)] && \
148+
Other.get##Name() != get##Name()) \
128149
return false;
129150
#include "llvm/ProfileData/MIBEntryDef.inc"
130151
#undef MIBEntryDef
@@ -155,6 +176,9 @@ struct PortableMemInfoBlock {
155176
}
156177

157178
private:
179+
// The set of available fields, indexed by Meta::Name.
180+
std::bitset<llvm::to_underlying(Meta::Size)> Schema;
181+
158182
#define MIBEntryDef(NameTag, Name, Type) Type Name = Type();
159183
#include "llvm/ProfileData/MIBEntryDef.inc"
160184
#undef MIBEntryDef
@@ -296,8 +320,9 @@ struct IndexedAllocationInfo {
296320

297321
IndexedAllocationInfo() = default;
298322
IndexedAllocationInfo(ArrayRef<FrameId> CS, CallStackId CSId,
299-
const MemInfoBlock &MB)
300-
: CallStack(CS.begin(), CS.end()), CSId(CSId), Info(MB) {}
323+
const MemInfoBlock &MB,
324+
const MemProfSchema &Schema = getFullSchema())
325+
: CallStack(CS.begin(), CS.end()), CSId(CSId), Info(MB, Schema) {}
301326

302327
// Returns the size in bytes when this allocation info struct is serialized.
303328
size_t serializedSize(const MemProfSchema &Schema,

llvm/unittests/ProfileData/InstrProfTest.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -407,13 +407,13 @@ IndexedMemProfRecord makeRecord(
407407
IndexedMemProfRecord
408408
makeRecordV2(std::initializer_list<::llvm::memprof::CallStackId> AllocFrames,
409409
std::initializer_list<::llvm::memprof::CallStackId> CallSiteFrames,
410-
const MemInfoBlock &Block) {
410+
const MemInfoBlock &Block, const memprof::MemProfSchema &Schema) {
411411
llvm::memprof::IndexedMemProfRecord MR;
412412
for (const auto &CSId : AllocFrames)
413413
// We don't populate IndexedAllocationInfo::CallStack because we use it only
414414
// in Version0 and Version1.
415415
MR.AllocSites.emplace_back(::llvm::SmallVector<memprof::FrameId>(), CSId,
416-
Block);
416+
Block, Schema);
417417
for (const auto &CSId : CallSiteFrames)
418418
MR.CallSiteIds.push_back(CSId);
419419
return MR;
@@ -506,7 +506,7 @@ TEST_F(InstrProfTest, test_memprof_v2_full_schema) {
506506

507507
const IndexedMemProfRecord IndexedMR = makeRecordV2(
508508
/*AllocFrames=*/{0x111, 0x222},
509-
/*CallSiteFrames=*/{0x333}, MIB);
509+
/*CallSiteFrames=*/{0x333}, MIB, memprof::getFullSchema());
510510
const FrameIdMapTy IdToFrameMap = getFrameMapping();
511511
const auto CSIdToCallStackMap = getCallStackMapping();
512512
for (const auto &I : IdToFrameMap) {
@@ -548,7 +548,7 @@ TEST_F(InstrProfTest, test_memprof_v2_partial_schema) {
548548

549549
const IndexedMemProfRecord IndexedMR = makeRecordV2(
550550
/*AllocFrames=*/{0x111, 0x222},
551-
/*CallSiteFrames=*/{0x333}, MIB);
551+
/*CallSiteFrames=*/{0x333}, MIB, memprof::getHotColdSchema());
552552
const FrameIdMapTy IdToFrameMap = getFrameMapping();
553553
const auto CSIdToCallStackMap = getCallStackMapping();
554554
for (const auto &I : IdToFrameMap) {

llvm/unittests/ProfileData/MemProfTest.cpp

Lines changed: 61 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
#include "llvm/ProfileData/MemProf.h"
22
#include "llvm/ADT/DenseMap.h"
33
#include "llvm/ADT/MapVector.h"
4+
#include "llvm/ADT/STLForwardCompat.h"
45
#include "llvm/DebugInfo/DIContext.h"
56
#include "llvm/DebugInfo/Symbolize/SymbolizableModule.h"
67
#include "llvm/IR/Value.h"
@@ -241,7 +242,7 @@ TEST(MemProf, PortableWrapper) {
241242
/*dealloc_cpu=*/4);
242243

243244
const auto Schema = llvm::memprof::getFullSchema();
244-
PortableMemInfoBlock WriteBlock(Info);
245+
PortableMemInfoBlock WriteBlock(Info, Schema);
245246

246247
std::string Buffer;
247248
llvm::raw_string_ostream OS(Buffer);
@@ -326,6 +327,65 @@ TEST(MemProf, RecordSerializationRoundTripVerion2) {
326327
EXPECT_EQ(Record, GotRecord);
327328
}
328329

330+
TEST(MemProf, RecordSerializationRoundTripVersion2HotColdSchema) {
331+
const auto Schema = llvm::memprof::getHotColdSchema();
332+
333+
MemInfoBlock Info;
334+
Info.AllocCount = 11;
335+
Info.TotalSize = 22;
336+
Info.TotalLifetime = 33;
337+
Info.TotalLifetimeAccessDensity = 44;
338+
339+
llvm::SmallVector<llvm::memprof::CallStackId> CallStackIds = {0x123, 0x456};
340+
341+
llvm::SmallVector<llvm::memprof::CallStackId> CallSiteIds = {0x333, 0x444};
342+
343+
IndexedMemProfRecord Record;
344+
for (const auto &CSId : CallStackIds) {
345+
// Use the same info block for both allocation sites.
346+
Record.AllocSites.emplace_back(llvm::SmallVector<FrameId>(), CSId, Info,
347+
Schema);
348+
}
349+
Record.CallSiteIds.assign(CallSiteIds);
350+
351+
std::bitset<llvm::to_underlying(Meta::Size)> SchemaBitSet;
352+
for (auto Id : Schema)
353+
SchemaBitSet.set(llvm::to_underlying(Id));
354+
355+
// Verify that SchemaBitSet has the fields we expect and nothing else, which
356+
// we check with count().
357+
EXPECT_EQ(SchemaBitSet.count(), 4U);
358+
EXPECT_TRUE(SchemaBitSet[llvm::to_underlying(Meta::AllocCount)]);
359+
EXPECT_TRUE(SchemaBitSet[llvm::to_underlying(Meta::TotalSize)]);
360+
EXPECT_TRUE(SchemaBitSet[llvm::to_underlying(Meta::TotalLifetime)]);
361+
EXPECT_TRUE(
362+
SchemaBitSet[llvm::to_underlying(Meta::TotalLifetimeAccessDensity)]);
363+
364+
// Verify that Schema has propagated all the way to the Info field in each
365+
// IndexedAllocationInfo.
366+
ASSERT_THAT(Record.AllocSites, ::SizeIs(2));
367+
EXPECT_EQ(Record.AllocSites[0].Info.getSchema(), SchemaBitSet);
368+
EXPECT_EQ(Record.AllocSites[1].Info.getSchema(), SchemaBitSet);
369+
370+
std::string Buffer;
371+
llvm::raw_string_ostream OS(Buffer);
372+
Record.serialize(Schema, OS, llvm::memprof::Version2);
373+
OS.flush();
374+
375+
const IndexedMemProfRecord GotRecord = IndexedMemProfRecord::deserialize(
376+
Schema, reinterpret_cast<const unsigned char *>(Buffer.data()),
377+
llvm::memprof::Version2);
378+
379+
// Verify that Schema comes back correctly after deserialization. Technically,
380+
// the comparison between Record and GotRecord below includes the comparison
381+
// of their Schemas, but we'll verify the Schemas on our own.
382+
ASSERT_THAT(GotRecord.AllocSites, ::SizeIs(2));
383+
EXPECT_EQ(GotRecord.AllocSites[0].Info.getSchema(), SchemaBitSet);
384+
EXPECT_EQ(GotRecord.AllocSites[1].Info.getSchema(), SchemaBitSet);
385+
386+
EXPECT_EQ(Record, GotRecord);
387+
}
388+
329389
TEST(MemProf, SymbolizationFilter) {
330390
std::unique_ptr<MockSymbolizer> Symbolizer(new MockSymbolizer());
331391

0 commit comments

Comments
 (0)