vector/test_util/
mod.rs

1#![allow(missing_docs)]
2
3use std::{
4    collections::HashMap,
5    convert::Infallible,
6    fs::File,
7    future::{Future, ready},
8    io::Read,
9    iter,
10    net::SocketAddr,
11    path::{Path, PathBuf},
12    pin::Pin,
13    sync::{
14        Arc,
15        atomic::{AtomicUsize, Ordering},
16    },
17    task::{Context, Poll, ready},
18};
19
20use chrono::{DateTime, SubsecRound, Utc};
21use flate2::read::MultiGzDecoder;
22use futures::{FutureExt, SinkExt, Stream, StreamExt, TryStreamExt, stream, task::noop_waker_ref};
23use openssl::ssl::{SslConnector, SslFiletype, SslMethod, SslVerifyMode};
24use rand::{Rng, rng};
25use rand_distr::Alphanumeric;
26use tokio::{
27    io::{AsyncRead, AsyncWrite, AsyncWriteExt, Result as IoResult},
28    net::{TcpListener, TcpStream, ToSocketAddrs},
29    runtime,
30    sync::oneshot,
31    task::JoinHandle,
32    time::{Duration, Instant, sleep},
33};
34use tokio_stream::wrappers::TcpListenerStream;
35#[cfg(unix)]
36use tokio_stream::wrappers::UnixListenerStream;
37use tokio_util::codec::{Encoder, FramedRead, FramedWrite, LinesCodec};
38use vector_lib::{
39    buffers::topology::channel::LimitedReceiver,
40    event::{
41        BatchNotifier, BatchStatusReceiver, Event, EventArray, LogEvent, Metric, MetricKind,
42        MetricTags, MetricValue,
43    },
44};
45#[cfg(test)]
46use zstd::Decoder as ZstdDecoder;
47
48use crate::{
49    config::{Config, GenerateConfig},
50    topology::{RunningTopology, ShutdownErrorReceiver},
51    trace,
52};
53
/// Default overall timeout, in seconds, used by `wait_for`.
const WAIT_FOR_SECS: u64 = 5;
/// Initial pause, in milliseconds, before the first retry in `wait_for_duration`.
const WAIT_FOR_MIN_MILLIS: u64 = 5;
/// Upper bound, in milliseconds, on the (doubling) pause between retries in `wait_for_duration`.
const WAIT_FOR_MAX_MILLIS: u64 = 500;
57
58pub mod addr;
59pub mod compression;
60pub mod stats;
61
62#[cfg(any(test, feature = "test-utils"))]
63pub mod components;
64#[cfg(test)]
65pub mod http;
66#[cfg(test)]
67pub mod integration;
68#[cfg(test)]
69pub mod metrics;
70#[cfg(test)]
71pub mod mock;
72
/// Asserts that `$e.downcast_ref::<$t>()` yields `Some(value)` where `value` matches `$v`.
///
/// # Panics
/// Panics when the downcast fails or the downcast value does not match the pattern; the
/// message includes the actual downcast result formatted with `{:?}`.
#[macro_export]
macro_rules! assert_downcast_matches {
    ($e:expr_2021, $t:ty, $v:pat) => {{
        let got = $e.downcast_ref::<$t>();
        if !matches!(got, Some($v)) {
            panic!("Assertion failed: got wrong error variant {:?}", got);
        }
    }};
}
82
/// Builds an `Event::Log` from `key => value` pairs.
///
/// Each pair is inserted, in order, into an initially empty `LogEvent`, which is then wrapped
/// in `Event::Log`. A trailing comma is accepted, and an empty invocation yields an empty log.
#[macro_export]
macro_rules! log_event {
    ($($key:expr_2021 => $value:expr_2021),*  $(,)?) => {
        #[allow(unused_variables, unused_mut)]
        {
            let mut log = $crate::event::LogEvent::default();
            $(
                log.insert($key, $value);
            )*
            $crate::event::Event::Log(log)
        }
    };
}
97
98pub fn test_generate_config<T>()
99where
100    for<'de> T: GenerateConfig + serde::Deserialize<'de>,
101{
102    let cfg = toml::to_string(&T::generate_config()).unwrap();
103
104    toml::from_str::<T>(&cfg)
105        .unwrap_or_else(|e| panic!("Invalid config generated from string:\n\n{e}\n'{cfg}'"));
106}
107
108pub fn open_fixture(path: impl AsRef<Path>) -> crate::Result<serde_json::Value> {
109    let test_file = File::open(path)?;
110    let value: serde_json::Value = serde_json::from_reader(test_file)?;
111    Ok(value)
112}
113
114pub fn trace_init() {
115    #[cfg(unix)]
116    let color = {
117        use std::io::IsTerminal;
118        std::io::stdout().is_terminal()
119            || std::env::var("NEXTEST")
120                .ok()
121                .and(Some(true))
122                .unwrap_or(false)
123    };
124    // Windows: ANSI colors are not supported by cmd.exe
125    // Color is false for everything except unix.
126    #[cfg(not(unix))]
127    let color = false;
128
129    let levels = std::env::var("VECTOR_LOG").unwrap_or_else(|_| "error".to_string());
130
131    trace::init(color, false, &levels, 10);
132
133    // Initialize metrics as well
134    vector_lib::metrics::init_test();
135}
136
137pub async fn send_lines(
138    addr: SocketAddr,
139    lines: impl IntoIterator<Item = String>,
140) -> Result<SocketAddr, Infallible> {
141    send_encodable(addr, LinesCodec::new(), lines).await
142}
143
144pub async fn send_encodable<I, E: From<std::io::Error> + std::fmt::Debug>(
145    addr: SocketAddr,
146    encoder: impl Encoder<I, Error = E>,
147    lines: impl IntoIterator<Item = I>,
148) -> Result<SocketAddr, Infallible> {
149    let stream = TcpStream::connect(&addr).await.unwrap();
150
151    let local_addr = stream.local_addr().unwrap();
152
153    let mut sink = FramedWrite::new(stream, encoder);
154
155    let mut lines = stream::iter(lines.into_iter()).map(Ok);
156    sink.send_all(&mut lines).await.unwrap();
157
158    let stream = sink.get_mut();
159    stream.shutdown().await.unwrap();
160
161    Ok(local_addr)
162}
163
164pub async fn send_lines_tls(
165    addr: SocketAddr,
166    host: String,
167    lines: impl Iterator<Item = String>,
168    ca: impl Into<Option<&Path>>,
169    client_cert: impl Into<Option<&Path>>,
170    client_key: impl Into<Option<&Path>>,
171) -> Result<SocketAddr, Infallible> {
172    let stream = TcpStream::connect(&addr).await.unwrap();
173
174    let local_addr = stream.local_addr().unwrap();
175
176    let mut connector = SslConnector::builder(SslMethod::tls()).unwrap();
177    if let Some(ca) = ca.into() {
178        connector.set_ca_file(ca).unwrap();
179    } else {
180        connector.set_verify(SslVerifyMode::NONE);
181    }
182
183    if let Some(cert_file) = client_cert.into() {
184        connector.set_certificate_chain_file(cert_file).unwrap();
185    }
186
187    if let Some(key_file) = client_key.into() {
188        connector
189            .set_private_key_file(key_file, SslFiletype::PEM)
190            .unwrap();
191    }
192
193    let ssl = connector
194        .build()
195        .configure()
196        .unwrap()
197        .into_ssl(&host)
198        .unwrap();
199
200    let mut stream = tokio_openssl::SslStream::new(ssl, stream).unwrap();
201    Pin::new(&mut stream).connect().await.unwrap();
202    let mut sink = FramedWrite::new(stream, LinesCodec::new());
203
204    let mut lines = stream::iter(lines).map(Ok);
205    sink.send_all(&mut lines).await.unwrap();
206
207    let stream = sink.get_mut().get_mut();
208    stream.shutdown().await.unwrap();
209
210    Ok(local_addr)
211}
212
213pub fn temp_file() -> PathBuf {
214    let path = std::env::temp_dir();
215    let file_name = random_string(16);
216    path.join(file_name + ".log")
217}
218
219pub fn temp_dir() -> PathBuf {
220    let path = std::env::temp_dir();
221    let dir_name = random_string(16);
222    path.join(dir_name)
223}
224
225pub fn random_table_name() -> String {
226    format!("test_{}", random_string(10).to_lowercase())
227}
228
229pub fn map_event_batch_stream(
230    stream: impl Stream<Item = Event>,
231    batch: Option<BatchNotifier>,
232) -> impl Stream<Item = EventArray> {
233    stream.map(move |event| event.with_batch_notifier_option(&batch).into())
234}
235
236// TODO refactor to have a single implementation for `Event`, `LogEvent` and `Metric`.
237fn map_batch_stream(
238    stream: impl Stream<Item = LogEvent>,
239    batch: Option<BatchNotifier>,
240) -> impl Stream<Item = EventArray> {
241    stream.map(move |log| vec![log.with_batch_notifier_option(&batch)].into())
242}
243
244pub fn generate_lines_with_stream<Gen: FnMut(usize) -> String>(
245    generator: Gen,
246    count: usize,
247    batch: Option<BatchNotifier>,
248) -> (Vec<String>, impl Stream<Item = EventArray>) {
249    let lines = (0..count).map(generator).collect::<Vec<_>>();
250    let stream = map_batch_stream(
251        stream::iter(lines.clone()).map(LogEvent::from_str_legacy),
252        batch,
253    );
254    (lines, stream)
255}
256
257pub fn random_lines_with_stream(
258    len: usize,
259    count: usize,
260    batch: Option<BatchNotifier>,
261) -> (Vec<String>, impl Stream<Item = EventArray>) {
262    let generator = move |_| random_string(len);
263    generate_lines_with_stream(generator, count, batch)
264}
265
266pub fn generate_events_with_stream<Gen: FnMut(usize) -> Event>(
267    generator: Gen,
268    count: usize,
269    batch: Option<BatchNotifier>,
270) -> (Vec<Event>, impl Stream<Item = EventArray>) {
271    let events = (0..count).map(generator).collect::<Vec<_>>();
272    let stream = map_batch_stream(
273        stream::iter(events.clone()).map(|event| event.into_log()),
274        batch,
275    );
276    (events, stream)
277}
278
279pub fn random_metrics_with_stream(
280    count: usize,
281    batch: Option<BatchNotifier>,
282    tags: Option<MetricTags>,
283) -> (Vec<Event>, impl Stream<Item = EventArray>) {
284    random_metrics_with_stream_timestamp(
285        count,
286        batch,
287        tags,
288        Utc::now().trunc_subsecs(3),
289        std::time::Duration::from_secs(2),
290    )
291}
292
293/// Generates event metrics with the provided tags and timestamp.
294///
295/// # Parameters
296/// - `count`: the number of metrics to generate
297/// - `batch`: the batch notifier to use with the stream
298/// - `tags`: the tags to apply to each metric event
299/// - `timestamp`: the timestamp to use for each metric event
300/// - `timestamp_offset`: the offset from the `timestamp` to use for each additional metric
301///
302/// # Returns
303/// A tuple of the generated metric events and the stream of the generated events
304pub fn random_metrics_with_stream_timestamp(
305    count: usize,
306    batch: Option<BatchNotifier>,
307    tags: Option<MetricTags>,
308    timestamp: DateTime<Utc>,
309    timestamp_offset: std::time::Duration,
310) -> (Vec<Event>, impl Stream<Item = EventArray>) {
311    let events: Vec<_> = (0..count)
312        .map(|index| {
313            let ts = timestamp + (timestamp_offset * index as u32);
314            Event::Metric(
315                Metric::new(
316                    format!("counter_{}", rng().random::<u32>()),
317                    MetricKind::Incremental,
318                    MetricValue::Counter {
319                        value: index as f64,
320                    },
321                )
322                .with_timestamp(Some(ts))
323                .with_tags(tags.clone()),
324            )
325            // this ensures we get Origin Metadata, with an undefined service but that's ok.
326            .with_source_type("a_source_like_none_other")
327        })
328        .collect();
329
330    let stream = map_event_batch_stream(stream::iter(events.clone()), batch);
331    (events, stream)
332}
333
334pub fn random_events_with_stream(
335    len: usize,
336    count: usize,
337    batch: Option<BatchNotifier>,
338) -> (Vec<Event>, impl Stream<Item = EventArray>) {
339    let events = (0..count)
340        .map(|_| Event::from(LogEvent::from_str_legacy(random_string(len))))
341        .collect::<Vec<_>>();
342    let stream = map_batch_stream(
343        stream::iter(events.clone()).map(|event| event.into_log()),
344        batch,
345    );
346    (events, stream)
347}
348
349pub fn random_updated_events_with_stream<F>(
350    len: usize,
351    count: usize,
352    batch: Option<BatchNotifier>,
353    update_fn: F,
354) -> (Vec<Event>, impl Stream<Item = EventArray>)
355where
356    F: Fn((usize, LogEvent)) -> LogEvent,
357{
358    let events = (0..count)
359        .map(|_| LogEvent::from_str_legacy(random_string(len)))
360        .enumerate()
361        .map(update_fn)
362        .map(Event::Log)
363        .collect::<Vec<_>>();
364    let stream = map_batch_stream(
365        stream::iter(events.clone()).map(|event| event.into_log()),
366        batch,
367    );
368    (events, stream)
369}
370
371pub fn create_events_batch_with_fn<F: Fn() -> Event>(
372    create_event_fn: F,
373    num_events: usize,
374) -> (Vec<Event>, BatchStatusReceiver) {
375    let mut events = (0..num_events)
376        .map(|_| create_event_fn())
377        .collect::<Vec<_>>();
378    let receiver = BatchNotifier::apply_to(&mut events);
379    (events, receiver)
380}
381
382pub fn random_string(len: usize) -> String {
383    rng()
384        .sample_iter(&Alphanumeric)
385        .take(len)
386        .map(char::from)
387        .collect::<String>()
388}
389
390pub fn random_lines(len: usize) -> impl Iterator<Item = String> {
391    iter::repeat_with(move || random_string(len))
392}
393
394pub fn random_map(max_size: usize, field_len: usize) -> HashMap<String, String> {
395    let size = rng().random_range(0..max_size);
396
397    (0..size)
398        .map(move |_| (random_string(field_len), random_string(field_len)))
399        .collect()
400}
401
402pub fn random_maps(
403    max_size: usize,
404    field_len: usize,
405) -> impl Iterator<Item = HashMap<String, String>> {
406    iter::repeat_with(move || random_map(max_size, field_len))
407}
408
409pub async fn collect_n<S>(rx: S, n: usize) -> Vec<S::Item>
410where
411    S: Stream,
412{
413    rx.take(n).collect().await
414}
415
416pub async fn collect_n_stream<T, S: Stream<Item = T> + Unpin>(stream: &mut S, n: usize) -> Vec<T> {
417    let mut events = Vec::with_capacity(n);
418
419    while events.len() < n {
420        let e = stream.next().await.unwrap();
421        events.push(e);
422    }
423    events
424}
425
426pub async fn collect_ready<S>(mut rx: S) -> Vec<S::Item>
427where
428    S: Stream + Unpin,
429{
430    let waker = noop_waker_ref();
431    let mut cx = Context::from_waker(waker);
432
433    let mut vec = Vec::new();
434    loop {
435        match rx.poll_next_unpin(&mut cx) {
436            Poll::Ready(Some(item)) => vec.push(item),
437            Poll::Ready(None) | Poll::Pending => return vec,
438        }
439    }
440}
441
442pub async fn collect_limited<T: Send + 'static>(mut rx: LimitedReceiver<T>) -> Vec<T> {
443    let mut items = Vec::new();
444    while let Some(item) = rx.next().await {
445        items.push(item);
446    }
447    items
448}
449
450pub async fn collect_n_limited<T: Send + 'static>(mut rx: LimitedReceiver<T>, n: usize) -> Vec<T> {
451    let mut items = Vec::new();
452    while items.len() < n {
453        match rx.next().await {
454            Some(item) => items.push(item),
455            None => break,
456        }
457    }
458    items
459}
460
461pub fn lines_from_file<P: AsRef<Path>>(path: P) -> Vec<String> {
462    trace!(message = "Reading file.", path = %path.as_ref().display());
463    let mut file = File::open(path).unwrap();
464    let mut output = String::new();
465    file.read_to_string(&mut output).unwrap();
466    output.lines().map(|s| s.to_owned()).collect()
467}
468
469pub fn lines_from_gzip_file<P: AsRef<Path>>(path: P) -> Vec<String> {
470    trace!(message = "Reading gzip file.", path = %path.as_ref().display());
471    let mut file = File::open(path).unwrap();
472    let mut gzip_bytes = Vec::new();
473    file.read_to_end(&mut gzip_bytes).unwrap();
474    let mut output = String::new();
475    MultiGzDecoder::new(&gzip_bytes[..])
476        .read_to_string(&mut output)
477        .unwrap();
478    output.lines().map(|s| s.to_owned()).collect()
479}
480
/// Decompresses the zstd file at `path` and returns its lines, without line terminators.
///
/// # Panics
/// Panics if the file cannot be opened, is not valid zstd, or does not decode to UTF-8.
#[cfg(test)]
pub fn lines_from_zstd_file<P: AsRef<Path>>(path: P) -> Vec<String> {
    trace!(message = "Reading zstd file.", path = %path.as_ref().display());
    let compressed = File::open(path).unwrap();
    let mut decoder = ZstdDecoder::new(compressed).unwrap();
    let mut output = String::new();
    decoder.read_to_string(&mut output).unwrap();
    output.lines().map(String::from).collect()
}
492
493pub fn runtime() -> runtime::Runtime {
494    runtime::Builder::new_multi_thread()
495        .enable_all()
496        .build()
497        .unwrap()
498}
499
500// Wait for a Future to resolve, or the duration to elapse (will panic)
501pub async fn wait_for_duration<F, Fut>(mut f: F, duration: Duration)
502where
503    F: FnMut() -> Fut,
504    Fut: Future<Output = bool> + Send + 'static,
505{
506    let started = Instant::now();
507    let mut delay = WAIT_FOR_MIN_MILLIS;
508    while !f().await {
509        sleep(Duration::from_millis(delay)).await;
510        if started.elapsed() > duration {
511            panic!("Timed out while waiting");
512        }
513        // quadratic backoff up to a maximum delay
514        delay = (delay * 2).min(WAIT_FOR_MAX_MILLIS);
515    }
516}
517
518// Wait for 5 seconds
519pub async fn wait_for<F, Fut>(f: F)
520where
521    F: FnMut() -> Fut,
522    Fut: Future<Output = bool> + Send + 'static,
523{
524    wait_for_duration(f, Duration::from_secs(WAIT_FOR_SECS)).await
525}
526
527// Wait (for 5 secs) for a TCP socket to be reachable
528pub async fn wait_for_tcp<A>(addr: A)
529where
530    A: ToSocketAddrs + Clone + Send + 'static,
531{
532    wait_for(move || {
533        let addr = addr.clone();
534        async move { TcpStream::connect(addr).await.is_ok() }
535    })
536    .await
537}
538
539// Allows specifying a custom duration to wait for a TCP socket to be reachable
540pub async fn wait_for_tcp_duration(addr: SocketAddr, duration: Duration) {
541    wait_for_duration(
542        || async move { TcpStream::connect(addr).await.is_ok() },
543        duration,
544    )
545    .await
546}
547
548pub async fn wait_for_atomic_usize<T, F>(value: T, unblock: F)
549where
550    T: AsRef<AtomicUsize>,
551    F: Fn(usize) -> bool,
552{
553    let value = value.as_ref();
554    wait_for(|| ready(unblock(value.load(Ordering::SeqCst)))).await
555}
556
557pub async fn wait_for_atomic_usize_timeout_ms<T, F>(value: T, unblock: F, timeout_ms: u64)
558where
559    T: AsRef<AtomicUsize>,
560    F: Fn(usize) -> bool,
561{
562    let value = value.as_ref();
563    wait_for_duration(
564        || ready(unblock(value.load(Ordering::SeqCst))),
565        Duration::from_millis(timeout_ms),
566    )
567    .await
568}
569
570// Retries a func every `retry` duration until given an Ok(T); panics after `until` elapses
571pub async fn retry_until<'a, F, Fut, T, E>(mut f: F, retry: Duration, until: Duration) -> T
572where
573    F: FnMut() -> Fut,
574    Fut: Future<Output = Result<T, E>> + Send + 'a,
575{
576    let started = Instant::now();
577    while started.elapsed() < until {
578        match f().await {
579            Ok(res) => return res,
580            Err(_) => tokio::time::sleep(retry).await,
581        }
582    }
583    panic!("Timeout")
584}
585
/// Counts and collects the items produced by a background task.
///
/// Awaiting the receiver fires the stop trigger and then resolves to every item the task
/// collected.
pub struct CountReceiver<T> {
    /// Number of items received so far, shared with the background task.
    count: Arc<AtomicUsize>,
    /// Stop signal handed to the background task; taken and fired on the first poll.
    trigger: Option<oneshot::Sender<()>>,
    /// Resolves once the background task signals that it is connected/ready.
    connected: Option<oneshot::Receiver<()>>,
    /// Handle to the spawned task, which yields the collected items.
    handle: JoinHandle<Vec<T>>,
}
592
593impl<T: Send + 'static> CountReceiver<T> {
594    pub fn count(&self) -> usize {
595        self.count.load(Ordering::Relaxed)
596    }
597
598    /// Succeeds once first connection has been made.
599    pub async fn connected(&mut self) {
600        if let Some(tripwire) = self.connected.take() {
601            tripwire.await.unwrap();
602        }
603    }
604
605    fn new<F, Fut>(make_fut: F) -> CountReceiver<T>
606    where
607        F: FnOnce(Arc<AtomicUsize>, oneshot::Receiver<()>, oneshot::Sender<()>) -> Fut,
608        Fut: Future<Output = Vec<T>> + Send + 'static,
609    {
610        let count = Arc::new(AtomicUsize::new(0));
611        let (trigger, tripwire) = oneshot::channel();
612        let (trigger_connected, connected) = oneshot::channel();
613
614        CountReceiver {
615            count: Arc::clone(&count),
616            trigger: Some(trigger),
617            connected: Some(connected),
618            handle: tokio::spawn(make_fut(count, tripwire, trigger_connected)),
619        }
620    }
621
622    pub fn receive_items_stream<S, F, Fut>(make_stream: F) -> CountReceiver<T>
623    where
624        S: Stream<Item = T> + Send + 'static,
625        F: FnOnce(oneshot::Receiver<()>, oneshot::Sender<()>) -> Fut + Send + 'static,
626        Fut: Future<Output = S> + Send + 'static,
627    {
628        CountReceiver::new(|count, tripwire, connected| async move {
629            let stream = make_stream(tripwire, connected).await;
630            stream
631                .inspect(move |_| {
632                    count.fetch_add(1, Ordering::Relaxed);
633                })
634                .collect::<Vec<T>>()
635                .await
636        })
637    }
638}
639
640impl<T> Future for CountReceiver<T> {
641    type Output = Vec<T>;
642
643    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
644        let this = self.get_mut();
645        if let Some(trigger) = this.trigger.take() {
646            _ = trigger.send(());
647        }
648
649        let result = ready!(this.handle.poll_unpin(cx));
650        Poll::Ready(result.unwrap())
651    }
652}
653
654impl CountReceiver<String> {
655    pub fn receive_lines(addr: SocketAddr) -> CountReceiver<String> {
656        CountReceiver::new(|count, tripwire, connected| async move {
657            let listener = TcpListener::bind(addr).await.unwrap();
658            CountReceiver::receive_lines_stream(
659                TcpListenerStream::new(listener),
660                count,
661                tripwire,
662                Some(connected),
663            )
664            .await
665        })
666    }
667
668    #[cfg(unix)]
669    pub fn receive_lines_unix<P>(path: P) -> CountReceiver<String>
670    where
671        P: AsRef<Path> + Send + 'static,
672    {
673        CountReceiver::new(|count, tripwire, connected| async move {
674            let listener = tokio::net::UnixListener::bind(path).unwrap();
675            CountReceiver::receive_lines_stream(
676                UnixListenerStream::new(listener),
677                count,
678                tripwire,
679                Some(connected),
680            )
681            .await
682        })
683    }
684
685    async fn receive_lines_stream<S, T>(
686        stream: S,
687        count: Arc<AtomicUsize>,
688        tripwire: oneshot::Receiver<()>,
689        mut connected: Option<oneshot::Sender<()>>,
690    ) -> Vec<String>
691    where
692        S: Stream<Item = IoResult<T>>,
693        T: AsyncWrite + AsyncRead,
694    {
695        stream
696            .take_until(tripwire)
697            .map_ok(|socket| FramedRead::new(socket, LinesCodec::new()))
698            .map(|x| {
699                connected.take().map(|trigger| trigger.send(()));
700                x.unwrap()
701            })
702            .flatten()
703            .map(|x| x.unwrap())
704            .inspect(move |_| {
705                count.fetch_add(1, Ordering::Relaxed);
706            })
707            .collect::<Vec<String>>()
708            .await
709    }
710}
711
712impl CountReceiver<Event> {
713    pub fn receive_events<S>(stream: S) -> CountReceiver<Event>
714    where
715        S: Stream<Item = Event> + Send + 'static,
716    {
717        CountReceiver::new(|count, tripwire, connected| async move {
718            connected.send(()).unwrap();
719            stream
720                .take_until(tripwire)
721                .inspect(move |_| {
722                    count.fetch_add(1, Ordering::Relaxed);
723                })
724                .collect::<Vec<Event>>()
725                .await
726        })
727    }
728}
729
730pub async fn start_topology(
731    mut config: Config,
732    require_healthy: impl Into<Option<bool>>,
733) -> (RunningTopology, ShutdownErrorReceiver) {
734    config.healthchecks.set_require_healthy(require_healthy);
735    RunningTopology::start_init_validated(config, Default::default())
736        .await
737        .unwrap()
738}
739
740/// Collect the first `n` events from a stream while a future is spawned
741/// in the background. This is used for tests where the collect has to
742/// happen concurrent with the sending process (ie the stream is
743/// handling finalization, which is required for the future to receive
744/// an acknowledgement).
745pub async fn spawn_collect_n<F, S>(future: F, stream: S, n: usize) -> Vec<Event>
746where
747    F: Future<Output = ()> + Send + 'static,
748    S: Stream<Item = Event>,
749{
750    // TODO: Switch to using `select!` so that we can drive `future` to completion while also driving `collect_n`,
751    // such that if `future` panics, we break out and don't continue driving `collect_n`. In most cases, `future`
752    // completing successfully is what actually drives events into `stream`, so continuing to wait for all N events when
753    // the catalyst has failed is.... almost never the desired behavior.
754    let sender = tokio::spawn(future);
755    let events = collect_n(stream, n).await;
756    sender.await.expect("Failed to send data");
757    events
758}
759
760/// Collect all the ready events from a stream after spawning a future
761/// in the background and letting it run for a given interval. This is
762/// used for tests where the collect has to happen concurrent with the
763/// sending process (ie the stream is handling finalization, which is
764/// required for the future to receive an acknowledgement).
765pub async fn spawn_collect_ready<F, S>(future: F, stream: S, sleep: u64) -> Vec<Event>
766where
767    F: Future<Output = ()> + Send + 'static,
768    S: Stream<Item = Event> + Unpin,
769{
770    let sender = tokio::spawn(future);
771    tokio::time::sleep(Duration::from_secs(sleep)).await;
772    let events = collect_ready(stream).await;
773    sender.await.expect("Failed to send data");
774    events
775}
776
#[cfg(test)]
mod tests {
    use std::{
        sync::{Arc, RwLock},
        time::Duration,
    };

    use super::retry_until;

    // Helper which errors the first 3 times and succeeds on the 4th call.
    async fn retry_until_helper(count: Arc<RwLock<i32>>) -> Result<(), ()> {
        // Take the write lock once so the check and the increment happen atomically
        // (a read-then-write pair would be a check-then-act race).
        let mut attempts = count.write().unwrap();
        if *attempts < 3 {
            *attempts += 1;
            return Err(());
        }
        Ok(())
    }

    #[tokio::test]
    async fn retry_until_before_timeout() {
        let count = Arc::new(RwLock::new(0));
        let func = || {
            let count = Arc::clone(&count);
            retry_until_helper(count)
        };

        retry_until(func, Duration::from_millis(10), Duration::from_secs(1)).await;

        // Three failed attempts must precede the successful fourth one.
        assert_eq!(*count.read().unwrap(), 3);
    }
}