1use std::{io, time::Duration};
4
5use socket2::{SockRef, TcpKeepalive};
6use tokio::net::TcpStream;
7
8pub fn set_receive_buffer_size<'s, S>(socket: &'s S, size: usize) -> io::Result<()>
18where
19 SockRef<'s>: From<&'s S>,
20{
21 SockRef::from(socket).set_recv_buffer_size(size)
22}
23
24pub fn set_send_buffer_size<'s, S>(socket: &'s S, size: usize) -> io::Result<()>
34where
35 SockRef<'s>: From<&'s S>,
36{
37 SockRef::from(socket).set_send_buffer_size(size)
38}
39
40pub fn set_keepalive(socket: &TcpStream, ttl: Duration) -> io::Result<()> {
51 SockRef::from(socket).set_tcp_keepalive(&TcpKeepalive::new().with_time(ttl))
52}
53
54const SSL_R_PROTOCOL_IS_SHUTDOWN: std::ffi::c_int = 207;
57
58pub fn is_graceful_tls_shutdown(err: &io::Error) -> bool {
67 let Some(ssl) = err
68 .get_ref()
69 .and_then(|inner| inner.downcast_ref::<openssl::ssl::Error>())
70 else {
71 return false;
72 };
73 if ssl.code() == openssl::ssl::ErrorCode::ZERO_RETURN {
74 return true;
75 }
76 ssl.ssl_error().is_some_and(|stack| {
77 stack
78 .errors()
79 .iter()
80 .any(|e| e.reason_code() == SSL_R_PROTOCOL_IS_SHUTDOWN)
81 })
82}
83
// Tests for `is_graceful_tls_shutdown`. The first test covers plain I/O errors
// (must never be classified as graceful). The second covers a real TLS session
// over loopback that is shut down in both directions.
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use openssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod, SslVerifyMode};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;
    use tokio_openssl::SslStream;

    use crate::tls::{TEST_PEM_CA_PATH, TEST_PEM_CRT_PATH, TEST_PEM_KEY_PATH};

    use super::{TcpStream, io, is_graceful_tls_shutdown};

    /// Errors that do not wrap an `openssl::ssl::Error` must be reported as
    /// non-graceful, whatever their `io::ErrorKind`.
    #[test]
    fn plain_io_errors_are_not_graceful() {
        for err in [
            io::Error::from(io::ErrorKind::BrokenPipe),
            io::Error::from(io::ErrorKind::ConnectionReset),
            io::Error::from(io::ErrorKind::UnexpectedEof),
            io::Error::other("not an ssl error"),
        ] {
            assert!(
                !is_graceful_tls_shutdown(&err),
                "expected non-graceful, got graceful for {err:?}",
            );
        }
    }

    /// End-to-end check against a real `tokio_openssl::SslStream`.
    ///
    /// The server sends its `close_notify` and the client observes EOF. The client
    /// then sends its own `close_notify`, completing a bidirectional shutdown. A
    /// subsequent client write must fail with an error that
    /// `is_graceful_tls_shutdown` classifies as graceful.
    #[tokio::test]
    async fn detects_graceful_shutdown_from_real_ssl_stream() {
        // Port 0: let the OS pick a free loopback port.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Server side: accept one connection, complete the TLS handshake with the
        // test key/certificate, then immediately initiate TLS shutdown.
        // NOTE(review): the task returns (dropping `tls`) right after its shutdown,
        // so the client's own `shutdown()` below may write to an already-closed
        // peer. If that ever proves flaky, consider having the server read until
        // EOF before returning.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
            acceptor
                .set_private_key_file(TEST_PEM_KEY_PATH, SslFiletype::PEM)
                .unwrap();
            acceptor
                .set_certificate_chain_file(TEST_PEM_CRT_PATH)
                .unwrap();
            let acceptor = acceptor.build();
            let ssl = openssl::ssl::Ssl::new(acceptor.context()).unwrap();
            let mut tls = SslStream::new(ssl, stream).unwrap();
            Pin::new(&mut tls).accept().await.unwrap();
            Pin::new(&mut tls).shutdown().await.unwrap();
        });

        // Client side. NOTE(review): certificate verification is disabled
        // (`SslVerifyMode::NONE`), so the CA file loaded here appears to have no
        // effect on the handshake. Confirm whether loading it is intentional.
        let mut connector = SslConnector::builder(SslMethod::tls()).unwrap();
        connector.set_ca_file(TEST_PEM_CA_PATH).unwrap();
        connector.set_verify(SslVerifyMode::NONE);
        let ssl = connector
            .build()
            .configure()
            .unwrap()
            .into_ssl("localhost")
            .unwrap();
        let stream = TcpStream::connect(addr).await.unwrap();
        let mut tls = SslStream::new(ssl, stream).unwrap();
        Pin::new(&mut tls).connect().await.unwrap();

        // The server's close_notify must surface to the client as a clean EOF (Ok(0)).
        let mut buf = [0u8; 1];
        let n = tls.read(&mut buf).await.unwrap();
        assert_eq!(n, 0, "expected EOF from peer's close_notify");

        // Send our own close_notify, completing the bidirectional shutdown.
        Pin::new(&mut tls).shutdown().await.unwrap();

        // Writing on the fully shut-down session must fail...
        let err = tls
            .write_all(b"too late")
            .await
            .expect_err("write after bidirectional shutdown should fail");

        // ...and that failure must be classified as a graceful TLS shutdown.
        assert!(
            is_graceful_tls_shutdown(&err),
            "expected graceful shutdown detection, got: {err:?} (inner: {:?})",
            err.get_ref(),
        );

        // Surface any panic from the server task.
        server.await.unwrap();
    }
}