Skip to content

Ensure orderly shutdown of ssl socket #7291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions ports/espressif/common-hal/socketpool/Socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
#include "py/mperrno.h"
#include "py/runtime.h"
#include "shared-bindings/socketpool/SocketPool.h"
#include "shared-bindings/ssl/SSLSocket.h"
#include "common-hal/ssl/SSLSocket.h"
#include "supervisor/port.h"
#include "supervisor/shared/tick.h"
#include "supervisor/workflow.h"
Expand All @@ -44,7 +46,7 @@
StackType_t socket_select_stack[2 * configMINIMAL_STACK_SIZE];

STATIC int open_socket_fds[CONFIG_LWIP_MAX_SOCKETS];
STATIC bool user_socket[CONFIG_LWIP_MAX_SOCKETS];
STATIC socketpool_socket_obj_t *user_socket[CONFIG_LWIP_MAX_SOCKETS];
StaticTask_t socket_select_task_handle;
STATIC int socket_change_fd = -1;

Expand Down Expand Up @@ -117,7 +119,7 @@ void socket_user_reset(void) {

for (size_t i = 0; i < MP_ARRAY_SIZE(open_socket_fds); i++) {
open_socket_fds[i] = -1;
user_socket[i] = false;
user_socket[i] = NULL;
}
socket_change_fd = eventfd(0, 0);
// Run this at the same priority as CP so that the web workflow background task can be
Expand All @@ -134,12 +136,13 @@ void socket_user_reset(void) {

for (size_t i = 0; i < MP_ARRAY_SIZE(open_socket_fds); i++) {
if (open_socket_fds[i] >= 0 && user_socket[i]) {
common_hal_socketpool_socket_close(user_socket[i]);
int num = open_socket_fds[i];
// Close automatically clears socket handle
lwip_shutdown(num, SHUT_RDWR);
lwip_close(num);
open_socket_fds[i] = -1;
user_socket[i] = false;
user_socket[i] = NULL;
}
}
}
Expand Down Expand Up @@ -171,10 +174,10 @@ STATIC void unregister_open_socket(int fd) {
}
}

STATIC void mark_user_socket(int fd) {
STATIC void mark_user_socket(int fd, socketpool_socket_obj_t *obj) {
for (size_t i = 0; i < MP_ARRAY_SIZE(open_socket_fds); i++) {
if (open_socket_fds[i] == fd) {
user_socket[i] = true;
user_socket[i] = obj;
return;
}
}
Expand Down Expand Up @@ -236,7 +239,7 @@ socketpool_socket_obj_t *common_hal_socketpool_socket(socketpool_socketpool_obj_
if (!socketpool_socket(self, family, type, sock)) {
mp_raise_RuntimeError(translate("Out of sockets"));
}
mark_user_socket(sock->num);
mark_user_socket(sock->num, sock);
return sock;
}

Expand Down Expand Up @@ -292,12 +295,12 @@ int socketpool_socket_accept(socketpool_socket_obj_t *self, uint8_t *ip, uint32_

socketpool_socket_obj_t *common_hal_socketpool_socket_accept(socketpool_socket_obj_t *self,
uint8_t *ip, uint32_t *port) {
socketpool_socket_obj_t *sock = m_new_obj_with_finaliser(socketpool_socket_obj_t);
int newsoc = socketpool_socket_accept(self, ip, port, NULL);

if (newsoc > 0) {
mark_user_socket(newsoc);
// Create the socket
socketpool_socket_obj_t *sock = m_new_obj_with_finaliser(socketpool_socket_obj_t);
mark_user_socket(newsoc, sock);
sock->base.type = &socketpool_socket_type;
sock->num = newsoc;
sock->pool = self->pool;
Expand Down Expand Up @@ -338,6 +341,12 @@ bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *self,
}

void socketpool_socket_close(socketpool_socket_obj_t *self) {
if (self->ssl_socket) {
ssl_sslsocket_obj_t *ssl_socket = self->ssl_socket;
self->ssl_socket = NULL;
common_hal_ssl_sslsocket_close(ssl_socket);
return;
}
self->connected = false;
if (self->num >= 0) {
lwip_shutdown(self->num, SHUT_RDWR);
Expand Down
8 changes: 4 additions & 4 deletions ports/espressif/common-hal/socketpool/Socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
* THE SOFTWARE.
*/

#ifndef MICROPY_INCLUDED_ESPRESSIF_COMMON_HAL_SOCKETPOOL_SOCKET_H
#define MICROPY_INCLUDED_ESPRESSIF_COMMON_HAL_SOCKETPOOL_SOCKET_H
#pragma once

#include "py/obj.h"

Expand All @@ -34,6 +33,8 @@

#include "components/esp-tls/esp_tls.h"

typedef struct ssl_sslsocket_obj ssl_sslsocket_obj_t;

typedef struct {
mp_obj_base_t base;
int num;
Expand All @@ -42,9 +43,8 @@ typedef struct {
int ipproto;
bool connected;
socketpool_socketpool_obj_t *pool;
ssl_sslsocket_obj_t *ssl_socket;
mp_uint_t timeout_ms;
} socketpool_socket_obj_t;

void socket_user_reset(void);

#endif // MICROPY_INCLUDED_ESPRESSIF_COMMON_HAL_SOCKETPOOL_SOCKET_H
1 change: 1 addition & 0 deletions ports/espressif/common-hal/ssl/SSLContext.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
sock->base.type = &ssl_sslsocket_type;
sock->ssl_context = self;
sock->sock = socket;
socket->ssl_socket = sock;

// Create a copy of the ESP-TLS config object and store the server hostname
// Note that ESP-TLS will use common_name for both SNI and verification
Expand Down
2 changes: 1 addition & 1 deletion ports/espressif/common-hal/ssl/SSLSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

#include "components/esp-tls/esp_tls.h"

typedef struct {
typedef struct ssl_sslsocket_obj {
mp_obj_base_t base;
socketpool_socket_obj_t *sock;
esp_tls_t *tls;
Expand Down