vector/
net.rs

1//! Networking-related helper functions.
2
3use std::{io, time::Duration};
4
5use socket2::{SockRef, TcpKeepalive};
6use tokio::net::TcpStream;
7
8/// Sets the receive buffer size for a socket.
9///
10/// This is the equivalent of setting the `SO_RCVBUF` socket setting directly.
11///
12/// # Errors
13///
14/// If there is an error setting the receive buffer size on the given socket, or if the value given
15/// as the socket is not a valid socket, an error variant will be returned explaining the underlying
16/// I/O error.
17pub fn set_receive_buffer_size<'s, S>(socket: &'s S, size: usize) -> io::Result<()>
18where
19    SockRef<'s>: From<&'s S>,
20{
21    SockRef::from(socket).set_recv_buffer_size(size)
22}
23
24/// Sets the send buffer size for a socket.
25///
26/// This is the equivalent of setting the `SO_SNDBUF` socket setting directly.
27///
28/// # Errors
29///
30/// If there is an error setting the send buffer size on the given socket, or if the value given
31/// as the socket is not a valid socket, an error variant will be returned explaining the underlying
32/// I/O error.
33pub fn set_send_buffer_size<'s, S>(socket: &'s S, size: usize) -> io::Result<()>
34where
35    SockRef<'s>: From<&'s S>,
36{
37    SockRef::from(socket).set_send_buffer_size(size)
38}
39
40/// Sets the TCP keepalive behavior on a socket.
41///
42/// This is the equivalent of setting the `SO_KEEPALIVE` and `TCP_KEEPALIVE` socket settings
43/// directly.
44///
45/// # Errors
46///
47/// If there is an error with either enabling keepalive probes or setting the TCP keepalive idle
48/// timeout on the given socket, an error variant will be returned explaining the underlying I/O
49/// error.
50pub fn set_keepalive(socket: &TcpStream, ttl: Duration) -> io::Result<()> {
51    SockRef::from(socket).set_tcp_keepalive(&TcpKeepalive::new().with_time(ttl))
52}
53
54// SSL_R_PROTOCOL_IS_SHUTDOWN from openssl/include/openssl/sslerr.h. Stable across
55// OpenSSL 1.1.1 and 3.x. Not re-exported by the `openssl-sys` crate so we pin it here.
56const SSL_R_PROTOCOL_IS_SHUTDOWN: std::ffi::c_int = 207;
57
58/// Returns true when an `io::Error` represents a peer-initiated, graceful TLS
59/// shutdown (close_notify), rather than a real I/O failure.
60///
61/// Two cases are recognized:
62/// - `SSL_ERROR_ZERO_RETURN`: the peer sent `close_notify` and we observed it
63///   during this I/O call.
64/// - `SSL_R_PROTOCOL_IS_SHUTDOWN`: a subsequent write after the session was
65///   already shut down ("ssl session has been shut down").
66pub fn is_graceful_tls_shutdown(err: &io::Error) -> bool {
67    let Some(ssl) = err
68        .get_ref()
69        .and_then(|inner| inner.downcast_ref::<openssl::ssl::Error>())
70    else {
71        return false;
72    };
73    if ssl.code() == openssl::ssl::ErrorCode::ZERO_RETURN {
74        return true;
75    }
76    ssl.ssl_error().is_some_and(|stack| {
77        stack
78            .errors()
79            .iter()
80            .any(|e| e.reason_code() == SSL_R_PROTOCOL_IS_SHUTDOWN)
81    })
82}
83
#[cfg(test)]
mod tests {
    //! Tests for `is_graceful_tls_shutdown`.
    //!
    //! The negative case uses plain `io::Error` values. The positive case drives a real TLS
    //! session over loopback, so the error under test is produced by `tokio_openssl` itself.

    use std::pin::Pin;

    use openssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;
    use tokio_openssl::SslStream;

    use crate::tls::{TEST_PEM_CA_PATH, TEST_PEM_CRT_PATH, TEST_PEM_KEY_PATH};

    use super::{TcpStream, io, is_graceful_tls_shutdown};

    // None of these errors wraps an `openssl::ssl::Error`, so none may be classified as a
    // graceful TLS shutdown. This includes the BrokenPipe / ConnectionReset / UnexpectedEof
    // kinds that commonly accompany a peer going away.
    #[test]
    fn plain_io_errors_are_not_graceful() {
        for err in [
            io::Error::from(io::ErrorKind::BrokenPipe),
            io::Error::from(io::ErrorKind::ConnectionReset),
            io::Error::from(io::ErrorKind::UnexpectedEof),
            io::Error::other("not an ssl error"),
        ] {
            assert!(
                !is_graceful_tls_shutdown(&err),
                "expected non-graceful, got graceful for {err:?}",
            );
        }
    }

    // Drives a real TLS handshake between two local sockets and completes a
    // bidirectional SSL shutdown. A subsequent write surfaces a `std::io::Error`
    // wrapping an `openssl::ssl::Error` from the same code path production hits.
    // This validates that the helper correctly identifies it as a graceful shutdown
    // without having to synthesize an `openssl::ssl::Error` (whose fields are
    // crate-private). Bidirectional shutdown is what reliably elicits
    // SSL_R_PROTOCOL_IS_SHUTDOWN; a half-closed session would still permit
    // writes per RFC 5246.
    #[tokio::test]
    async fn detects_graceful_shutdown_from_real_ssl_stream() {
        // Port 0 lets the OS pick a free loopback port; the chosen address is read back below.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Server side: accept a single connection, complete the TLS handshake with the test
        // key and certificate chain, then shut the TLS session down.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
            acceptor
                .set_private_key_file(TEST_PEM_KEY_PATH, SslFiletype::PEM)
                .unwrap();
            acceptor
                .set_certificate_chain_file(TEST_PEM_CRT_PATH)
                .unwrap();
            let acceptor = acceptor.build();
            let ssl = openssl::ssl::Ssl::new(acceptor.context()).unwrap();
            let mut tls = SslStream::new(ssl, stream).unwrap();
            Pin::new(&mut tls).accept().await.unwrap();
            // Cleanly close the SSL session by sending close_notify to the client.
            // NOTE(review): tokio-openssl's `shutdown` may return as soon as close_notify is
            // sent rather than waiting for the client's reply; confirm. The client below only
            // relies on our close_notify having been sent.
            Pin::new(&mut tls).shutdown().await.unwrap();
        });

        // Client side: load the test CA but disable peer verification (SslVerifyMode::NONE), so
        // the handshake does not depend on the server certificate validating.
        let mut connector = SslConnector::builder(SslMethod::tls()).unwrap();
        connector.set_ca_file(TEST_PEM_CA_PATH).unwrap();
        connector.set_verify(SslVerifyMode::NONE);
        let ssl = connector
            .build()
            .configure()
            .unwrap()
            .into_ssl("localhost")
            .unwrap();
        let stream = TcpStream::connect(addr).await.unwrap();
        let mut tls = SslStream::new(ssl, stream).unwrap();
        Pin::new(&mut tls).connect().await.unwrap();

        // Drain the server's close_notify so our SSL state observes the peer shutdown.
        // A clean TLS close surfaces as a 0-byte read.
        let mut buf = [0u8; 1];
        let n = tls.read(&mut buf).await.unwrap();
        assert_eq!(n, 0, "expected EOF from peer's close_notify");

        // Complete the bidirectional SSL shutdown locally. Once both sides are
        // shut down, OpenSSL marks the session as SHUTDOWN and any further write
        // returns SSL_R_PROTOCOL_IS_SHUTDOWN ("ssl session has been shut down").
        Pin::new(&mut tls).shutdown().await.unwrap();

        let err = tls
            .write_all(b"too late")
            .await
            .expect_err("write after bidirectional shutdown should fail");

        // Include the inner error in the failure message so a mismatch shows the actual
        // OpenSSL error stack.
        assert!(
            is_graceful_tls_shutdown(&err),
            "expected graceful shutdown detection, got: {err:?} (inner: {:?})",
            err.get_ref(),
        );

        // Awaiting the join handle and unwrapping surfaces any panic from the server task.
        server.await.unwrap();
    }
}