diff --git a/libavformat/tls_openssl.c b/libavformat/tls_openssl.c index 780874f8fe..eade59b5d5 100644 --- a/libavformat/tls_openssl.c +++ b/libavformat/tls_openssl.c @@ -23,6 +23,7 @@ #include "libavutil/mem.h" #include "network.h" #include "os_support.h" +#include "libavutil/time.h" #include "libavutil/random_seed.h" #include "url.h" #include "tls.h" @@ -32,7 +33,11 @@ #include #include #include +#if HAVE_SYS_TIME_H +#include +#endif +#define DTLS_HANDSHAKE_TIMEOUT_US 30000000 /** * Convert an EVP_PKEY to a PEM string. */ @@ -623,30 +628,62 @@ static void openssl_info_callback(const SSL *ssl, int where, int ret) { static int dtls_handshake(URLContext *h) { - int ret = 1, r0, r1; TLSContext *c = h->priv_data; + int ret, err; + int timeout_ms; + struct timeval timeout; + int64_t timeout_start = av_gettime_relative(); + int sockfd = ffurl_get_file_handle(c->tls_shared.udp); + struct pollfd pfd = { .fd = sockfd, .events = POLLIN, .revents = 0 }; - c->tls_shared.udp->flags &= ~AVIO_FLAG_NONBLOCK; + /* Force NONBLOCK mode to handle DTLS retransmissions */ + c->tls_shared.udp->flags |= AVIO_FLAG_NONBLOCK; - r0 = SSL_do_handshake(c->ssl); - if (r0 <= 0) { - r1 = SSL_get_error(c->ssl, r0); - - if (r1 != SSL_ERROR_WANT_READ && r1 != SSL_ERROR_WANT_WRITE && r1 != SSL_ERROR_ZERO_RETURN) { - av_log(c, AV_LOG_ERROR, "Handshake failed, r0=%d, r1=%d\n", r0, r1); - ret = print_ssl_error(h, r0); + for (;;) { + if (av_gettime_relative() - timeout_start > DTLS_HANDSHAKE_TIMEOUT_US) { + ret = AVERROR(ETIMEDOUT); goto end; } - } else { - av_log(c, AV_LOG_TRACE, "Handshake success, r0=%d\n", r0); - } + ret = SSL_do_handshake(c->ssl); + if (ret == 1) { + av_log(c, AV_LOG_TRACE, "Handshake success\n"); + break; + } + err = SSL_get_error(c->ssl, ret); + if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE && err != SSL_ERROR_ZERO_RETURN) { + av_log(c, AV_LOG_ERROR, "Handshake failed, ret=%d, err=%d\n", ret, err); + ret = print_ssl_error(h, ret); + goto end; + } + + timeout_ms = 1000; + if (DTLSv1_get_timeout(c->ssl, &timeout)) + timeout_ms = timeout.tv_sec * 1000 + timeout.tv_usec / 1000; + + ret = poll(&pfd, 1, timeout_ms); + if (ret > 0 && (pfd.revents & POLLIN)) + continue; + if (!ret) { + if (DTLSv1_handle_timeout(c->ssl) < 0) { + ret = AVERROR(EIO); + goto end; + } + continue; + } + if (ret < 0) { + ret = ff_neterrno(); + goto end; + } + } /* Check whether the handshake is completed. */ if (SSL_is_init_finished(c->ssl) != TLS_ST_OK) goto end; ret = 0; end: + if (!(h->flags & AVIO_FLAG_NONBLOCK)) + c->tls_shared.udp->flags &= ~AVIO_FLAG_NONBLOCK; return ret; }