Skip to content

Commit be33d31

Browse files
LamForestpytorchmergebot
authored andcommitted
add std::ostream& operator<< for BFloat16 in BFloat16.h (pytorch#121302)
This PR Move `operator<<` of `BFloat16` to `BFloat16.h`. Previously, this function is in `TensorDataContainer.h`. If need `std::cout` a `BFloat16` variable when debugging, `TensorDataContainer.h` have to be included. This is inconvient and counterintuitive. Other dtypes such as `Half`, define their `operator<<` in headers where they are defined such as `Half.h`. Therefore, I think it makes more sense to move `operator<<` of `BFloat16` to `BFloat16.h` Pull Request resolved: pytorch#121302 Approved by: https://github.com/ezyang
1 parent 5986552 commit be33d31

File tree

8 files changed

+24
-14
lines changed

8 files changed

+24
-14
lines changed

c10/util/BFloat16.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#include <cstdint>
99
#include <cstring>
1010

11+
#include <iosfwd>
12+
1113
#if defined(__CUDACC__) && !defined(USE_ROCM)
1214
#include <cuda_bf16.h>
1315
#endif
@@ -112,6 +114,8 @@ struct alignas(2) BFloat16 {
112114
#endif
113115
};
114116

117+
C10_API std::ostream& operator<<(std::ostream& out, const BFloat16& value);
118+
115119
} // namespace c10
116120

117121
#include <c10/util/BFloat16-inl.h> // IWYU pragma: keep

c10/util/Bfloat16.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#include <c10/util/BFloat16.h>
2+
#include <ostream>
3+
#include <type_traits>
4+
5+
namespace c10 {
6+
7+
static_assert(
8+
std::is_standard_layout_v<BFloat16>,
9+
"c10::BFloat16 must be standard layout.");
10+
11+
std::ostream& operator<<(std::ostream& out, const BFloat16& value) {
12+
out << (float)value;
13+
return out;
14+
}
15+
} // namespace c10

c10/util/Float8_e4m3fn.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include <c10/util/Float8_e4m3fn.h>
2-
#include <iostream>
2+
#include <ostream>
33
#include <type_traits>
44

55
namespace c10 {

c10/util/Float8_e4m3fnuz.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include <c10/util/Float8_e4m3fnuz.h>
2-
#include <iostream>
2+
#include <ostream>
33

44
namespace c10 {
55

c10/util/Float8_e5m2.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include <c10/util/Float8_e5m2.h>
2-
#include <iostream>
2+
#include <ostream>
33

44
namespace c10 {
55

c10/util/Float8_e5m2fnuz.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include <c10/util/Float8_e5m2fnuz.h>
2-
#include <iostream>
2+
#include <ostream>
33

44
namespace c10 {
55

c10/util/Half.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include <c10/util/Half.h>
2-
#include <iostream>
2+
#include <ostream>
33
#include <type_traits>
44

55
namespace c10 {

torch/csrc/api/include/torch/detail/TensorDataContainer.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,6 @@ inline std::ostream& operator<<(
2828
std::ostream& stream,
2929
const TensorDataContainer& tensor_data_container);
3030

31-
// FIXME: There is no `operator<<` overload for `at::kBFloat16` type,
32-
// and we need to convert it to `float` type using `operator float()` function
33-
// defined in `c10/util/BFloat16.h`.
34-
// Tracking issue: https://github.com/pytorch/pytorch/issues/28845
35-
inline std::ostream& operator<<(std::ostream& stream, c10::BFloat16 value) {
36-
stream << static_cast<float>(value);
37-
return stream;
38-
}
39-
4031
inline c10::ScalarType compute_desired_dtype(c10::ScalarType scalar_type) {
4132
if (scalar_type == at::kInt || scalar_type == at::kLong) {
4233
// C++ `torch::tensor` with an integer type or an `at::ArrayRef` /

0 commit comments

Comments
 (0)