Skip to content

Commit f548d19

Browse files
authored
[libc] Fix and simplify the implementation of 'fread' on the GPU (#66948)
Summary: Previously, the `fread` operation was wrong in cases when we read less data than was requested. That is, if we tried to read N bytes while the file was in EOF, it would still copy N bytes of garbage. This is fixed by only copying over the sizes we got from locally opening it rather than just using the provided size. Additionally, this patch simplifies the interface. The output functions have special variants for writing to stdout / stderr. This is primarily an optimization for these common cases so we can avoid sending the stream as an argument which has a high delay. Because for input, we already need to start with a `send` to tell the server how much data to read, it costs us nothing to send the file along with it so this is redundant. Re-use the file encoding scheme from the other implementations, the one that stores the stream type in the LSBs of the FILE pointer.
1 parent d15f96f commit f548d19

File tree

3 files changed

+51
-62
lines changed

3 files changed

+51
-62
lines changed

libc/include/llvm-libc-types/rpc_opcodes_t.h

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,16 @@ typedef enum : unsigned short {
1515
RPC_WRITE_TO_STDOUT = 2,
1616
RPC_WRITE_TO_STDERR = 3,
1717
RPC_WRITE_TO_STREAM = 4,
18-
RPC_READ_FROM_STDIN = 5,
19-
RPC_READ_FROM_STREAM = 6,
20-
RPC_OPEN_FILE = 7,
21-
RPC_CLOSE_FILE = 8,
22-
RPC_MALLOC = 9,
23-
RPC_FREE = 10,
24-
RPC_HOST_CALL = 11,
25-
RPC_ABORT = 12,
26-
RPC_FEOF = 13,
27-
RPC_FERROR = 14,
28-
RPC_CLEARERR = 15,
18+
RPC_READ_FROM_STREAM = 5,
19+
RPC_OPEN_FILE = 6,
20+
RPC_CLOSE_FILE = 7,
21+
RPC_MALLOC = 8,
22+
RPC_FREE = 9,
23+
RPC_HOST_CALL = 10,
24+
RPC_ABORT = 11,
25+
RPC_FEOF = 12,
26+
RPC_FERROR = 13,
27+
RPC_CLEARERR = 14,
2928
} rpc_opcode_t;
3029

3130
#endif // __LLVM_LIBC_TYPES_RPC_OPCODE_H__

libc/src/stdio/gpu/file.h

Lines changed: 36 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,39 @@
1414
namespace __llvm_libc {
1515
namespace file {
1616

17+
enum Stream {
18+
File = 0,
19+
Stdin = 1,
20+
Stdout = 2,
21+
Stderr = 3,
22+
};
23+
24+
// When copying between the client and server we need to indicate if this is one
25+
// of the special streams. We do this by enocding the low order bits of the
26+
// pointer to indicate if we need to use the host's standard stream.
27+
LIBC_INLINE uintptr_t from_stream(::FILE *f) {
28+
if (f == stdin)
29+
return reinterpret_cast<uintptr_t>(f) | Stdin;
30+
if (f == stdout)
31+
return reinterpret_cast<uintptr_t>(f) | Stdout;
32+
if (f == stderr)
33+
return reinterpret_cast<uintptr_t>(f) | Stderr;
34+
return reinterpret_cast<uintptr_t>(f);
35+
}
36+
37+
// Get the associated stream out of an encoded number.
38+
LIBC_INLINE ::FILE *to_stream(uintptr_t f) {
39+
::FILE *stream = reinterpret_cast<FILE *>(f & ~0x3ull);
40+
Stream type = static_cast<Stream>(f & 0x3ull);
41+
if (type == Stdin)
42+
return stdin;
43+
if (type == Stdout)
44+
return stdout;
45+
if (type == Stderr)
46+
return stderr;
47+
return stream;
48+
}
49+
1750
template <uint16_t opcode>
1851
LIBC_INLINE uint64_t write_impl(::FILE *file, const void *data, size_t size) {
1952
uint64_t ret = 0;
@@ -42,15 +75,13 @@ LIBC_INLINE uint64_t write(::FILE *f, const void *data, size_t size) {
4275
return write_impl<RPC_WRITE_TO_STREAM>(f, data, size);
4376
}
4477

45-
template <uint16_t opcode>
4678
LIBC_INLINE uint64_t read_from_stream(::FILE *file, void *buf, size_t size) {
4779
uint64_t ret = 0;
4880
uint64_t recv_size;
49-
rpc::Client::Port port = rpc::client.open<opcode>();
81+
rpc::Client::Port port = rpc::client.open<RPC_READ_FROM_STREAM>();
5082
port.send([=](rpc::Buffer *buffer) {
5183
buffer->data[0] = size;
52-
if constexpr (opcode == RPC_READ_FROM_STREAM)
53-
buffer->data[1] = reinterpret_cast<uintptr_t>(file);
84+
buffer->data[1] = from_stream(file);
5485
});
5586
port.recv_n(&buf, &recv_size, [&](uint64_t) { return buf; });
5687
port.recv([&](rpc::Buffer *buffer) { ret = buffer->data[0]; });
@@ -59,43 +90,7 @@ LIBC_INLINE uint64_t read_from_stream(::FILE *file, void *buf, size_t size) {
5990
}
6091

6192
LIBC_INLINE uint64_t read(::FILE *f, void *data, size_t size) {
62-
if (f == stdin)
63-
return read_from_stream<RPC_READ_FROM_STDIN>(f, data, size);
64-
else
65-
return read_from_stream<RPC_READ_FROM_STREAM>(f, data, size);
66-
}
67-
68-
enum Stream {
69-
File = 0,
70-
Stdin = 1,
71-
Stdout = 2,
72-
Stderr = 3,
73-
};
74-
75-
// When copying between the client and server we need to indicate if this is one
76-
// of the special streams. We do this by enocding the low order bits of the
77-
// pointer to indicate if we need to use the host's standard stream.
78-
LIBC_INLINE uintptr_t from_stream(::FILE *f) {
79-
if (f == stdin)
80-
return reinterpret_cast<uintptr_t>(f) | Stdin;
81-
if (f == stdout)
82-
return reinterpret_cast<uintptr_t>(f) | Stdout;
83-
if (f == stderr)
84-
return reinterpret_cast<uintptr_t>(f) | Stderr;
85-
return reinterpret_cast<uintptr_t>(f);
86-
}
87-
88-
// Get the associated stream out of an encoded number.
89-
LIBC_INLINE ::FILE *to_stream(uintptr_t f) {
90-
::FILE *stream = reinterpret_cast<FILE *>(f & ~0x3ull);
91-
Stream type = static_cast<Stream>(f & 0x3ull);
92-
if (type == Stdin)
93-
return stdin;
94-
if (type == Stdout)
95-
return stdout;
96-
if (type == Stderr)
97-
return stderr;
98-
return stream;
93+
return read_from_stream(f, data, size);
9994
}
10095

10196
} // namespace file

libc/utils/gpu/server/rpc_server.cpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,23 +80,18 @@ struct Server {
8080
});
8181
break;
8282
}
83-
case RPC_READ_FROM_STREAM:
84-
case RPC_READ_FROM_STDIN: {
83+
case RPC_READ_FROM_STREAM: {
8584
uint64_t sizes[lane_size] = {0};
8685
void *data[lane_size] = {nullptr};
87-
uint64_t rets[lane_size] = {0};
8886
port->recv([&](rpc::Buffer *buffer, uint32_t id) {
89-
sizes[id] = buffer->data[0];
90-
data[id] = new char[sizes[id]];
91-
FILE *file = port->get_opcode() == RPC_READ_FROM_STREAM
92-
? reinterpret_cast<FILE *>(buffer->data[1])
93-
: stdin;
94-
rets[id] = fread(data[id], 1, sizes[id], file);
87+
data[id] = new char[buffer->data[0]];
88+
sizes[id] = fread(data[id], 1, buffer->data[0],
89+
file::to_stream(buffer->data[1]));
9590
});
9691
port->send_n(data, sizes);
9792
port->send([&](rpc::Buffer *buffer, uint32_t id) {
9893
delete[] reinterpret_cast<uint8_t *>(data[id]);
99-
std::memcpy(buffer->data, &rets[id], sizeof(uint64_t));
94+
std::memcpy(buffer->data, &sizes[id], sizeof(uint64_t));
10095
});
10196
break;
10297
}

0 commit comments

Comments
 (0)