Skip to content

[DFSan] Fix sscanf checking that ordinary characters match. #95333

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
browneee merged 1 commit into llvm:main on Jun 13, 2024

Conversation

browneee
Copy link
Contributor

Fixes: #94769

@llvmbot
Copy link
Member

llvmbot commented Jun 12, 2024

@llvm/pr-subscribers-compiler-rt-sanitizer

Author: Andrew Browne (browneee)

Changes

Fixes: #94769


Patch is 23.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95333.diff

2 Files Affected:

  • (modified) compiler-rt/lib/dfsan/dfsan_custom.cpp (+291-231)
  • (modified) compiler-rt/test/dfsan/sscanf.c (+102-9)
diff --git a/compiler-rt/lib/dfsan/dfsan_custom.cpp b/compiler-rt/lib/dfsan/dfsan_custom.cpp
index af3c1f4d1673c..050f5232c0408 100644
--- a/compiler-rt/lib/dfsan/dfsan_custom.cpp
+++ b/compiler-rt/lib/dfsan/dfsan_custom.cpp
@@ -2198,50 +2198,12 @@ struct Formatter {
     return retval;
   }
 
-  int scan() {
-    char *tmp_fmt = build_format_string(true);
-    int read_count = 0;
-    int retval = sscanf(str + str_off, tmp_fmt, &read_count);
-    if (retval > 0) {
-      if (-1 == num_scanned)
-        num_scanned = 0;
-      num_scanned += retval;
-    }
-    free(tmp_fmt);
-    return read_count;
-  }
-
-  template <typename T>
-  int scan(T arg) {
-    char *tmp_fmt = build_format_string(true);
-    int read_count = 0;
-    int retval = sscanf(str + str_off, tmp_fmt, arg, &read_count);
-    if (retval > 0) {
-      if (-1 == num_scanned)
-        num_scanned = 0;
-      num_scanned += retval;
-    }
-    free(tmp_fmt);
-    return read_count;
-  }
-
-  // with_n -> toggles adding %n on/off; off by default
-  char *build_format_string(bool with_n = false) {
+  char *build_format_string() {
     size_t fmt_size = fmt_cur - fmt_start + 1;
-    size_t add_size = 0;
-    if (with_n)
-      add_size = 2;
-    char *new_fmt = (char *)malloc(fmt_size + 1 + add_size);
+    char *new_fmt = (char *)malloc(fmt_size + 1);
     assert(new_fmt);
     internal_memcpy(new_fmt, fmt_start, fmt_size);
-    if (!with_n) {
-      new_fmt[fmt_size] = '\0';
-    } else {
-      new_fmt[fmt_size] = '%';
-      new_fmt[fmt_size + 1] = 'n';
-      new_fmt[fmt_size + 2] = '\0';
-    }
-
+    new_fmt[fmt_size] = '\0';
     return new_fmt;
   }
 
@@ -2467,6 +2429,102 @@ static int format_buffer(char *str, size_t size, const char *fmt,
   return formatter.str_off;
 }
 
+// Scans a chunk either a constant string or a single format directive (e.g.,
+// '%.3f').
+struct Scanner {
+  Scanner(char *str_, const char *fmt_, size_t size_)
+      : str(str_),
+        str_off(0),
+        size(size_),
+        fmt_start(fmt_),
+        fmt_cur(fmt_),
+        width(-1),
+        num_scanned(0),
+        skip(false) {}
+
+  // Consumes a chunk of ordinary characters.
+  // Returns number of matching ordinary characters.
+  // Returns -1 if the match failed.
+  // In format strings, a space will match multiple spaces.
+  int check_match_ordinary() {
+    char *tmp_fmt = build_format_string_with_n();
+    int read_count = -1;
+    sscanf(str + str_off, tmp_fmt, &read_count);
+    free(tmp_fmt);
+    if (read_count > 0) {
+      str_off += read_count;
+    }
+    return read_count;
+  }
+
+  int scan() {
+    char *tmp_fmt = build_format_string_with_n();
+    int read_count = 0;
+    int retval = sscanf(str + str_off, tmp_fmt, &read_count);
+    free(tmp_fmt);
+    if (retval > 0) {
+      num_scanned += retval;
+    }
+    return read_count;
+  }
+
+  template <typename T>
+  int scan(T arg) {
+    char *tmp_fmt = build_format_string_with_n();
+    int read_count = 0;
+    int retval = sscanf(str + str_off, tmp_fmt, arg, &read_count);
+    free(tmp_fmt);
+    if (retval > 0) {
+      num_scanned += retval;
+    }
+    return read_count;
+  }
+
+  // Adds %n onto current format string to measure length.
+  char *build_format_string_with_n() {
+    size_t fmt_size = fmt_cur - fmt_start + 1;
+    // +2 for %n, +1 for \0
+    char *new_fmt = (char *)malloc(fmt_size + 2 + 1);
+    assert(new_fmt);
+    internal_memcpy(new_fmt, fmt_start, fmt_size);
+    new_fmt[fmt_size] = '%';
+    new_fmt[fmt_size + 1] = 'n';
+    new_fmt[fmt_size + 2] = '\0';
+    return new_fmt;
+  }
+
+  char *str_cur() { return str + str_off; }
+
+  size_t num_written_bytes(int retval) {
+    if (retval < 0) {
+      return 0;
+    }
+
+    size_t num_avail = str_off < size ? size - str_off : 0;
+    if (num_avail == 0) {
+      return 0;
+    }
+
+    size_t num_written = retval;
+    // A return value of {v,}snprintf of size or more means that the output was
+    // truncated.
+    if (num_written >= num_avail) {
+      num_written -= num_avail;
+    }
+
+    return num_written;
+  }
+
+  char *str;
+  size_t str_off;
+  size_t size;
+  const char *fmt_start;
+  const char *fmt_cur;
+  int width;
+  int num_scanned;
+  bool skip;
+};
+
 // This function is an inverse of format_buffer: we take the input buffer,
 // scan it in search for format strings and store the results in the varargs.
 // The labels are propagated from the input buffer to the varargs.
@@ -2474,220 +2532,222 @@ static int scan_buffer(char *str, size_t size, const char *fmt,
                        dfsan_label *va_labels, dfsan_label *ret_label,
                        dfsan_origin *str_origin, dfsan_origin *ret_origin,
                        va_list ap) {
-  Formatter formatter(str, fmt, size);
-  while (*formatter.fmt_cur) {
-    formatter.fmt_start = formatter.fmt_cur;
-    formatter.width = -1;
-    formatter.skip = false;
+  Scanner scanner(str, fmt, size);
+  while (*scanner.fmt_cur) {
+    scanner.fmt_start = scanner.fmt_cur;
+    scanner.width = -1;
+    scanner.skip = false;
     int read_count = 0;
     void *dst_ptr = 0;
     size_t write_size = 0;
-    if (*formatter.fmt_cur != '%') {
-      // Ordinary character. Consume all the characters until a '%' or the end
-      // of the string.
-      for (; *(formatter.fmt_cur + 1) && *(formatter.fmt_cur + 1) != '%';
-           ++formatter.fmt_cur) {
+    if (*scanner.fmt_cur != '%') {
+      // Ordinary character and spaces.
+      // Consume all the characters until a '%' or the end of the string.
+      for (; *(scanner.fmt_cur + 1) && *(scanner.fmt_cur + 1) != '%';
+           ++scanner.fmt_cur) {
+      }
+      if (scanner.check_match_ordinary() < 0) {
+        // The ordinary characters did not match.
+        break;
       }
-      read_count = formatter.scan();
-      dfsan_set_label(0, formatter.str_cur(),
-                      formatter.num_written_bytes(read_count));
     } else {
       // Conversion directive. Consume all the characters until a conversion
       // specifier or the end of the string.
       bool end_fmt = false;
-      for (; *formatter.fmt_cur && !end_fmt;) {
-        switch (*++formatter.fmt_cur) {
-        case 'd':
-        case 'i':
-        case 'o':
-        case 'u':
-        case 'x':
-        case 'X':
-          if (formatter.skip) {
-            read_count = formatter.scan();
-          } else {
-            switch (*(formatter.fmt_cur - 1)) {
-            case 'h':
-              // Also covers the 'hh' case (since the size of the arg is still
-              // an int).
-              dst_ptr = va_arg(ap, int *);
-              read_count = formatter.scan((int *)dst_ptr);
-              write_size = sizeof(int);
-              break;
-            case 'l':
-              if (formatter.fmt_cur - formatter.fmt_start >= 2 &&
-                  *(formatter.fmt_cur - 2) == 'l') {
-                dst_ptr = va_arg(ap, long long int *);
-                read_count = formatter.scan((long long int *)dst_ptr);
-                write_size = sizeof(long long int);
-              } else {
-                dst_ptr = va_arg(ap, long int *);
-                read_count = formatter.scan((long int *)dst_ptr);
-                write_size = sizeof(long int);
+      for (; *scanner.fmt_cur && !end_fmt;) {
+        switch (*++scanner.fmt_cur) {
+          case 'd':
+          case 'i':
+          case 'o':
+          case 'u':
+          case 'x':
+          case 'X':
+            if (scanner.skip) {
+              read_count = scanner.scan();
+            } else {
+              switch (*(scanner.fmt_cur - 1)) {
+                case 'h':
+                  // Also covers the 'hh' case (since the size of the arg is
+                  // still an int).
+                  dst_ptr = va_arg(ap, int *);
+                  read_count = scanner.scan((int *)dst_ptr);
+                  write_size = sizeof(int);
+                  break;
+                case 'l':
+                  if (scanner.fmt_cur - scanner.fmt_start >= 2 &&
+                      *(scanner.fmt_cur - 2) == 'l') {
+                    dst_ptr = va_arg(ap, long long int *);
+                    read_count = scanner.scan((long long int *)dst_ptr);
+                    write_size = sizeof(long long int);
+                  } else {
+                    dst_ptr = va_arg(ap, long int *);
+                    read_count = scanner.scan((long int *)dst_ptr);
+                    write_size = sizeof(long int);
+                  }
+                  break;
+                case 'q':
+                  dst_ptr = va_arg(ap, long long int *);
+                  read_count = scanner.scan((long long int *)dst_ptr);
+                  write_size = sizeof(long long int);
+                  break;
+                case 'j':
+                  dst_ptr = va_arg(ap, intmax_t *);
+                  read_count = scanner.scan((intmax_t *)dst_ptr);
+                  write_size = sizeof(intmax_t);
+                  break;
+                case 'z':
+                case 't':
+                  dst_ptr = va_arg(ap, size_t *);
+                  read_count = scanner.scan((size_t *)dst_ptr);
+                  write_size = sizeof(size_t);
+                  break;
+                default:
+                  dst_ptr = va_arg(ap, int *);
+                  read_count = scanner.scan((int *)dst_ptr);
+                  write_size = sizeof(int);
+              }
+              // get the label associated with the string at the corresponding
+              // place
+              dfsan_label l = dfsan_read_label(
+                  scanner.str_cur(), scanner.num_written_bytes(read_count));
+              dfsan_set_label(l, dst_ptr, write_size);
+              if (str_origin != nullptr) {
+                dfsan_set_label(l, dst_ptr, write_size);
+                size_t scan_count = scanner.num_written_bytes(read_count);
+                size_t size = scan_count > write_size ? write_size : scan_count;
+                dfsan_mem_origin_transfer(dst_ptr, scanner.str_cur(), size);
               }
-              break;
-            case 'q':
-              dst_ptr = va_arg(ap, long long int *);
-              read_count = formatter.scan((long long int *)dst_ptr);
-              write_size = sizeof(long long int);
-              break;
-            case 'j':
-              dst_ptr = va_arg(ap, intmax_t *);
-              read_count = formatter.scan((intmax_t *)dst_ptr);
-              write_size = sizeof(intmax_t);
-              break;
-            case 'z':
-            case 't':
-              dst_ptr = va_arg(ap, size_t *);
-              read_count = formatter.scan((size_t *)dst_ptr);
-              write_size = sizeof(size_t);
-              break;
-            default:
-              dst_ptr = va_arg(ap, int *);
-              read_count = formatter.scan((int *)dst_ptr);
-              write_size = sizeof(int);
-            }
-            // get the label associated with the string at the corresponding
-            // place
-            dfsan_label l = dfsan_read_label(
-                formatter.str_cur(), formatter.num_written_bytes(read_count));
-            dfsan_set_label(l, dst_ptr, write_size);
-            if (str_origin != nullptr) {
-            dfsan_set_label(l, dst_ptr, write_size);
-            size_t scan_count = formatter.num_written_bytes(read_count);
-            size_t size = scan_count > write_size ? write_size : scan_count;
-            dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size);
             }
-          }
-          end_fmt = true;
+            end_fmt = true;
 
-          break;
+            break;
 
-        case 'a':
-        case 'A':
-        case 'e':
-        case 'E':
-        case 'f':
-        case 'F':
-        case 'g':
-        case 'G':
-          if (formatter.skip) {
-            read_count = formatter.scan();
-          } else {
-            if (*(formatter.fmt_cur - 1) == 'L') {
-            dst_ptr = va_arg(ap, long double *);
-            read_count = formatter.scan((long double *)dst_ptr);
-            write_size = sizeof(long double);
-            } else if (*(formatter.fmt_cur - 1) == 'l') {
-            dst_ptr = va_arg(ap, double *);
-            read_count = formatter.scan((double *)dst_ptr);
-            write_size = sizeof(double);
+          case 'a':
+          case 'A':
+          case 'e':
+          case 'E':
+          case 'f':
+          case 'F':
+          case 'g':
+          case 'G':
+            if (scanner.skip) {
+              read_count = scanner.scan();
             } else {
-            dst_ptr = va_arg(ap, float *);
-            read_count = formatter.scan((float *)dst_ptr);
-            write_size = sizeof(float);
-            }
-            dfsan_label l = dfsan_read_label(
-                formatter.str_cur(), formatter.num_written_bytes(read_count));
-            dfsan_set_label(l, dst_ptr, write_size);
-            if (str_origin != nullptr) {
-            dfsan_set_label(l, dst_ptr, write_size);
-            size_t scan_count = formatter.num_written_bytes(read_count);
-            size_t size = scan_count > write_size ? write_size : scan_count;
-            dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size);
+              if (*(scanner.fmt_cur - 1) == 'L') {
+                dst_ptr = va_arg(ap, long double *);
+                read_count = scanner.scan((long double *)dst_ptr);
+                write_size = sizeof(long double);
+              } else if (*(scanner.fmt_cur - 1) == 'l') {
+                dst_ptr = va_arg(ap, double *);
+                read_count = scanner.scan((double *)dst_ptr);
+                write_size = sizeof(double);
+              } else {
+                dst_ptr = va_arg(ap, float *);
+                read_count = scanner.scan((float *)dst_ptr);
+                write_size = sizeof(float);
+              }
+              dfsan_label l = dfsan_read_label(
+                  scanner.str_cur(), scanner.num_written_bytes(read_count));
+              dfsan_set_label(l, dst_ptr, write_size);
+              if (str_origin != nullptr) {
+                dfsan_set_label(l, dst_ptr, write_size);
+                size_t scan_count = scanner.num_written_bytes(read_count);
+                size_t size = scan_count > write_size ? write_size : scan_count;
+                dfsan_mem_origin_transfer(dst_ptr, scanner.str_cur(), size);
+              }
             }
-          }
-          end_fmt = true;
-          break;
+            end_fmt = true;
+            break;
 
-        case 'c':
-          if (formatter.skip) {
-            read_count = formatter.scan();
-          } else {
-            dst_ptr = va_arg(ap, char *);
-            read_count = formatter.scan((char *)dst_ptr);
-            write_size = sizeof(char);
-            dfsan_label l = dfsan_read_label(
-                formatter.str_cur(), formatter.num_written_bytes(read_count));
-            dfsan_set_label(l, dst_ptr, write_size);
-            if (str_origin != nullptr) {
-            size_t scan_count = formatter.num_written_bytes(read_count);
-            size_t size = scan_count > write_size ? write_size : scan_count;
-            dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size);
+          case 'c':
+            if (scanner.skip) {
+              read_count = scanner.scan();
+            } else {
+              dst_ptr = va_arg(ap, char *);
+              read_count = scanner.scan((char *)dst_ptr);
+              write_size = sizeof(char);
+              dfsan_label l = dfsan_read_label(
+                  scanner.str_cur(), scanner.num_written_bytes(read_count));
+              dfsan_set_label(l, dst_ptr, write_size);
+              if (str_origin != nullptr) {
+                size_t scan_count = scanner.num_written_bytes(read_count);
+                size_t size = scan_count > write_size ? write_size : scan_count;
+                dfsan_mem_origin_transfer(dst_ptr, scanner.str_cur(), size);
+              }
             }
-          }
-          end_fmt = true;
-          break;
+            end_fmt = true;
+            break;
 
-        case 's': {
-          if (formatter.skip) {
-            read_count = formatter.scan();
-          } else {
-            dst_ptr = va_arg(ap, char *);
-            read_count = formatter.scan((char *)dst_ptr);
-            if (1 == read_count) {
-            // special case: we have parsed a single string and we need to
-            // update read_count with the string size
-            read_count = strlen((char *)dst_ptr);
+          case 's': {
+            if (scanner.skip) {
+              read_count = scanner.scan();
+            } else {
+              dst_ptr = va_arg(ap, char *);
+              read_count = scanner.scan((char *)dst_ptr);
+              if (1 == read_count) {
+                // special case: we have parsed a single string and we need to
+                // update read_count with the string size
+                read_count = strlen((char *)dst_ptr);
+              }
+              if (str_origin)
+                dfsan_mem_origin_transfer(
+                    dst_ptr, scanner.str_cur(),
+                    scanner.num_written_bytes(read_count));
+              va_labels++;
+              dfsan_mem_shadow_transfer(dst_ptr, scanner.str_cur(),
+                                        scanner.num_written_bytes(read_count));
             }
-            if (str_origin)
-            dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(),
-                                      formatter.num_written_bytes(read_count));
-            va_labels++;
-            dfsan_mem_shadow_transfer(dst_ptr, formatter.str_cur(),
-                                      formatter.num_written_bytes(read_count));
+            end_fmt = true;
+            break;
           }
-          end_fmt = true;
-          break;
-        }
 
-        case 'p':
-          if (formatter.skip) {
-            read_count = formatter.scan();
-          } else {
-            dst_ptr = va_arg(ap, void *);
-            read_count =
-                formatter.scan((int *)dst_ptr);  // note: changing void* to int*
+          case 'p':
+            if (scanner.skip) {
+              read_count = scanner.scan();
+            } else {
+              dst_ptr = va_arg(ap, void *);
+              read_count =
+                  scanner.scan((int *)dst_ptr);  // note: changing void* to int*
                                                  // since we need to call sizeof
-            write_size = sizeof(int);
-
-            dfsan_label l = dfsan_read_label(
-                formatter.str_cur(), formatter.num_written_bytes(read_count));
-            dfsan_set_label(l, dst_ptr, write_size);
-            if (str_origin != nullptr) {
-            dfsan_set_label(l, dst_ptr, write_size);
-            size_t scan_count = formatter.num_written_bytes(read_count);
-            size_t size = scan_count > write_size ? write_size : scan_count;
-            dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size);
+              write_size = sizeof(int);
+
+              dfsan_label l = dfsan_read_label(
+                  scanner.str_cur(), scanner.num_written_bytes(read_count));
+              dfsan_set_label(l, dst_ptr, write_size);
+              if (str_origin != nullptr) {
+                dfsan_set_label(l, dst_ptr, write_size);
+                size_t scan_count = scanner.num_written_bytes(read_count);
+                size_t size = scan_count > write_size ? write_size : scan_count;
+                dfsan_mem_origin_transfer(dst_ptr, scanner.str_cur(), size);
+              }
             }
-          }
-          end_fmt = true;
-          break;
+            end_fmt = true;
+            break;
 
-        case 'n': {
-          if (!formatter.skip) {
-            int *ptr = va_arg(ap, int *);
-            *ptr = (int)formatter.str_off;
-            *va_labels++ = 0;
-            dfsan_set_label(0, ptr, sizeof(*ptr));
-            if (str_origin != nullptr)
-            *str_origin++ = 0;
+          case 'n': {
+            if (!scanner.skip) {
+              int *ptr = va_arg(ap, int *);
+              *ptr = (int)scanner.str_off;
+              *va_labels++ = 0;
+              dfsan_set_label(0, ptr, sizeof(*ptr));
+              if (str_origin != nullptr)
+                *str_origin++ = 0;
+            }
+            end_fmt = true;
+            break;
     ...
[truncated]

Copy link
Contributor

@thurstond thurstond left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Andrew!

@browneee browneee merged commit cd94fa7 into llvm:main Jun 13, 2024
9 checks passed
EthanLuisMcDonough pushed a commit to EthanLuisMcDonough/llvm-project that referenced this pull request Aug 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[dfsan] sscanf function incorrectly ignores ordinary characters in the format string
3 participants