24
24
* THE SOFTWARE.
25
25
*/
26
26
27
+ // Include strchrnul()
28
+ #define _GNU_SOURCE
29
+
27
30
#include <stdarg.h>
28
31
#include <string.h>
29
32
@@ -85,8 +88,8 @@ typedef struct {
85
88
char destination [256 ];
86
89
char header_key [64 ];
87
90
char header_value [256 ];
88
- // We store the origin so we can reply back with it.
89
- char origin [64 ];
91
+ char origin [ 64 ]; // We store the origin so we can reply back with it.
92
+ char host [64 ]; // We store the host to check against origin.
90
93
size_t content_length ;
91
94
size_t offset ;
92
95
uint64_t timestamp_ms ;
@@ -454,49 +457,33 @@ static bool _endswith(const char *str, const char *suffix) {
454
457
return strcmp (str + (strlen (str ) - strlen (suffix )), suffix ) == 0 ;
455
458
}
456
459
457
- const char * ok_hosts [] = {
458
- "127.0.0.1" ,
459
- "localhost" ,
460
- };
461
-
462
- static bool _origin_ok (const char * origin ) {
463
- const char * http = "http://" ;
460
+ const char http_scheme [] = "http://" ;
461
+ #define PREFIX_HTTP_LEN (sizeof(http_scheme) - 1)
464
462
465
- // note: redirected requests send an Origin of "null" and will be caught by this
466
- if (strncmp (origin , http , strlen (http )) != 0 ) {
467
- return false;
468
- }
469
- // These are prefix checks up to : so that any port works.
470
- // TODO: Support DHCP hostname in addition to MDNS.
471
- const char * end ;
472
- #if CIRCUITPY_MDNS
473
- if (!common_hal_mdns_server_deinited (& mdns )) {
474
- const char * local = ".local" ;
475
- const char * hostname = common_hal_mdns_server_get_hostname (& mdns );
476
- end = origin + strlen (http ) + strlen (hostname ) + strlen (local );
477
- if (strncmp (origin + strlen (http ), hostname , strlen (hostname )) == 0 &&
478
- strncmp (origin + strlen (http ) + strlen (hostname ), local , strlen (local )) == 0 &&
479
- (end [0 ] == '\0' || end [0 ] == ':' )) {
480
- return true;
481
- }
463
+ static bool _origin_ok (_request * request ) {
464
+ // Origin may be 'null'
465
+ if (request -> origin [0 ] == '\0' ) {
466
+ return true;
482
467
}
483
- #endif
484
-
485
- _update_encoded_ip ();
486
- end = origin + strlen (http ) + strlen (_our_ip_encoded );
487
- if (strncmp (origin + strlen (http ), _our_ip_encoded , strlen (_our_ip_encoded )) == 0 &&
488
- (end [0 ] == '\0' || end [0 ] == ':' )) {
468
+ // Origin has http prefix?
469
+ if (strncmp (request -> origin , http_scheme , PREFIX_HTTP_LEN ) != 0 ) {
470
+ // Not HTTP scheme request - ok
471
+ request -> origin [0 ] = '\0' ;
489
472
return true;
490
473
}
491
-
492
- for (size_t i = 0 ; i < MP_ARRAY_SIZE (ok_hosts ); i ++ ) {
493
- // Allows any port
494
- end = origin + strlen (http ) + strlen (ok_hosts [i ]);
495
- if (strncmp (origin + strlen (http ), ok_hosts [i ], strlen (ok_hosts [i ])) == 0
496
- && (end [0 ] == '\0' || end [0 ] == ':' )) {
474
+ // Host given?
475
+ if (request -> host [0 ] != '\0' ) {
476
+ // OK if host and origin match (fqdn + port #)
477
+ if (strcmp (request -> host , & request -> origin [PREFIX_HTTP_LEN ]) == 0 ) {
478
+ return true;
479
+ }
480
+ // DEBUG: OK if origin is 'localhost' (ignoring port #)
481
+ * strchrnul (& request -> origin [PREFIX_HTTP_LEN ], ':' ) = '\0' ;
482
+ if (strcmp (& request -> origin [PREFIX_HTTP_LEN ], "localhost" ) == 0 ) {
497
483
return true;
498
484
}
499
485
}
486
+ // Otherwise deny request
500
487
return false;
501
488
}
502
489
@@ -517,8 +504,8 @@ static void _cors_header(socketpool_socket_obj_t *socket, _request *request) {
517
504
_send_strs (socket ,
518
505
"Access-Control-Allow-Credentials: true\r\n" ,
519
506
"Vary: Origin, Accept, Upgrade\r\n" ,
520
- "Access-Control-Allow-Origin: *\r\n " ,
521
- NULL );
507
+ "Access-Control-Allow-Origin: " ,
508
+ ( request -> origin [ 0 ] == '\0' ) ? "*" : request -> origin , "\r\n" , NULL );
522
509
}
523
510
524
511
static void _reply_continue (socketpool_socket_obj_t * socket , _request * request ) {
@@ -1086,11 +1073,7 @@ static bool _reply(socketpool_socket_obj_t *socket, _request *request) {
1086
1073
#else
1087
1074
_reply_missing (socket , request );
1088
1075
#endif
1089
-
1090
- // For now until CORS is sorted, allow always the origin requester.
1091
- // Note: caller knows who we are better than us. CORS is not security
1092
- // unless browser cooperates. Do not rely on mDNS or IP.
1093
- } else if (strlen (request -> origin ) > 0 && !_origin_ok (request -> origin )) {
1076
+ } else if (!_origin_ok (request )) {
1094
1077
_reply_forbidden (socket , request );
1095
1078
} else if (strncmp (request -> path , "/fs/" , 4 ) == 0 ) {
1096
1079
if (strcasecmp (request -> method , "OPTIONS" ) == 0 ) {
@@ -1314,6 +1297,7 @@ static bool _reply(socketpool_socket_obj_t *socket, _request *request) {
1314
1297
static void _reset_request (_request * request ) {
1315
1298
request -> state = STATE_METHOD ;
1316
1299
request -> origin [0 ] = '\0' ;
1300
+ request -> host [0 ] = '\0' ;
1317
1301
request -> content_length = 0 ;
1318
1302
request -> offset = 0 ;
1319
1303
request -> timestamp_ms = 0 ;
@@ -1340,6 +1324,7 @@ static void _process_request(socketpool_socket_obj_t *socket, _request *request)
1340
1324
if (len == 0 || len == - MP_ENOTCONN ) {
1341
1325
// Disconnect - clear 'in-progress'
1342
1326
_reset_request (request );
1327
+ common_hal_socketpool_socket_close (socket );
1343
1328
}
1344
1329
break ;
1345
1330
}
@@ -1421,14 +1406,17 @@ static void _process_request(socketpool_socket_obj_t *socket, _request *request)
1421
1406
request -> redirect = strncmp (request -> header_value , cp_local , strlen (cp_local )) == 0 &&
1422
1407
(strlen (request -> header_value ) == strlen (cp_local ) ||
1423
1408
request -> header_value [strlen (cp_local )] == ':' );
1409
+ strncpy (request -> host , request -> header_value , sizeof (request -> host ) - 1 );
1410
+ request -> host [sizeof (request -> host ) - 1 ] = '\0' ;
1424
1411
} else if (strcasecmp (request -> header_key , "Content-Length" ) == 0 ) {
1425
1412
request -> content_length = strtoul (request -> header_value , NULL , 10 );
1426
1413
} else if (strcasecmp (request -> header_key , "Expect" ) == 0 ) {
1427
1414
request -> expect = strcmp (request -> header_value , "100-continue" ) == 0 ;
1428
1415
} else if (strcasecmp (request -> header_key , "Accept" ) == 0 ) {
1429
1416
request -> json = strcasecmp (request -> header_value , "application/json" ) == 0 ;
1430
1417
} else if (strcasecmp (request -> header_key , "Origin" ) == 0 ) {
1431
- strcpy (request -> origin , request -> header_value );
1418
+ strncpy (request -> origin , request -> header_value , sizeof (request -> origin ) - 1 );
1419
+ request -> origin [sizeof (request -> origin ) - 1 ] = '\0' ;
1432
1420
} else if (strcasecmp (request -> header_key , "X-Timestamp" ) == 0 ) {
1433
1421
request -> timestamp_ms = strtoull (request -> header_value , NULL , 10 );
1434
1422
} else if (strcasecmp (request -> header_key , "Upgrade" ) == 0 ) {
0 commit comments