Skip to main content

hydro_lang/deploy/
deploy_runtime_containerized.rs

1#![allow(
2    unused,
3    reason = "unused in trybuild but the __staged version is needed"
4)]
5#![allow(missing_docs, reason = "used internally")]
6
7use std::collections::HashMap;
8use std::future::Future;
9use std::net::SocketAddr;
10use std::ops::{Deref, DerefMut};
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll};
14use std::time::Duration;
15
16use bytes::BytesMut;
17use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
18use proc_macro2::Span;
19use sinktools::demux_map_lazy::LazyDemuxSink;
20use sinktools::lazy::{LazySink, LazySource};
21use sinktools::lazy_sink_source::LazySinkSource;
22use stageleft::runtime_support::{
23    FreeVariableWithContext, FreeVariableWithContextWithProps, QuoteTokens,
24};
25use stageleft::{QuotedWithContext, q};
26use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
27use tokio::net::{TcpListener, TcpStream};
28use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
29use tracing::{debug, instrument, warn};
30
31use crate::location::dynamic::LocationId;
32use crate::location::member_id::TaglessMemberId;
33use crate::location::{LocationKey, MemberId, MembershipEvent};
34
/// The single well-known port that every node listens on.
///
/// The same value is used both when binding the mux listener and when senders
/// build the `host:port` address they connect to.
pub const CHANNEL_MUX_PORT: u16 = 10000;
37
/// Magic constant embedded in every [`ChannelMagic`] header.
///
/// Read most-significant byte first, the value spells the ASCII string
/// `HYDRO_CH`.
pub const CHANNEL_MAGIC: u64 = 0x4859_4452_4f5f_4348;
40
/// Magic header sent as the very first frame of every channel handshake.
///
/// This is a fixed value that never changes across versions, used to confirm
/// both sides are speaking the same protocol family before anything else.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct ChannelMagic {
    /// Must equal [`CHANNEL_MAGIC`]; the accept loop drops connections that
    /// carry any other value.
    pub magic: u64,
}
49
/// Current protocol version for the channel handshake.
///
/// The accept loop rejects connections whose [`ChannelProtocolVersion`] frame
/// carries a different value.
pub const CHANNEL_PROTOCOL_VERSION: u64 = 1;
52
/// Protocol version sent as the second frame, after [`ChannelMagic`].
///
/// Incremented when the handshake format changes. The receiver checks this
/// to decide how to deserialize the subsequent [`ChannelHandshake`] frame.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct ChannelProtocolVersion {
    /// The sender's handshake protocol version; compared against
    /// [`CHANNEL_PROTOCOL_VERSION`] by the receiver.
    pub version: u64,
}
61
/// Handshake message sent by the connecting side to identify the channel.
///
/// The receiver reads this as the third frame (after [`ChannelMagic`] and
/// [`ChannelProtocolVersion`]) to know which logical channel the connection
/// belongs to, and optionally which cluster member is connecting.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct ChannelHandshake {
    /// The logical channel name for this connection.
    pub channel_name: String,
    /// If the sender is a cluster member, this is its identifier
    /// (container name for Docker, task ID for ECS, etc.).
    /// `None` for process-to-process connections.
    pub sender_id: Option<String>,
}
77
/// A dispatched channel connection: optional sender ID and the read stream.
///
/// The sender ID is the `sender_id` field of the peer's [`ChannelHandshake`];
/// the stream is the framed read half positioned just after the handshake
/// frames, so the consumer only sees payload frames.
type MuxConnection = (
    Option<String>,
    FramedRead<OwnedReadHalf, LengthDelimitedCodec>,
);
83
/// A shared accept loop that listens on a single port and dispatches
/// incoming connections to the right consumer based on the channel name
/// sent in the handshake.
///
/// Each node creates one of these at startup. Individual channels register
/// themselves and receive their connection via a mpsc channel.
pub struct ChannelMux {
    /// Map from channel name to a sender that delivers accepted connections.
    ///
    /// Guarded by a synchronous mutex; it is only locked for short map
    /// operations (insert on register, lookup on dispatch).
    channels: std::sync::Mutex<HashMap<String, tokio::sync::mpsc::UnboundedSender<MuxConnection>>>,
}
94
/// `Default` delegates to [`ChannelMux::new`], i.e. a mux with no registered
/// channels.
impl Default for ChannelMux {
    fn default() -> Self {
        Self::new()
    }
}
100
101impl ChannelMux {
102    pub fn new() -> Self {
103        Self {
104            channels: std::sync::Mutex::new(HashMap::new()),
105        }
106    }
107
108    pub fn register(
109        &self,
110        channel_name: String,
111    ) -> tokio::sync::mpsc::UnboundedReceiver<MuxConnection> {
112        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
113        let mut channels = self.channels.lock().unwrap();
114        channels.insert(channel_name, tx);
115        rx
116    }
117
118    pub async fn run_accept_loop(self: Arc<Self>, listener: TcpListener) {
119        loop {
120            let (stream, peer) = match listener.accept().await {
121                Ok(v) => v,
122                Err(e) => {
123                    warn!(name: "accept_error", error = %e);
124                    continue;
125                }
126            };
127            debug!(name: "mux_accepting", ?peer);
128
129            let mux = self.clone();
130            tokio::spawn(async move {
131                let (rx, _tx) = stream.into_split();
132                let mut source = FramedRead::new(rx, LengthDelimitedCodec::new());
133
134                let magic_frame = match source.next().await {
135                    Some(Ok(frame)) => frame,
136                    _ => {
137                        warn!(name: "magic_failed", ?peer, "no magic frame");
138                        return;
139                    }
140                };
141
142                let magic: ChannelMagic = match bincode::deserialize(&magic_frame) {
143                    Ok(m) => m,
144                    Err(e) => {
145                        warn!(name: "magic_deserialize_failed", ?peer, error = %e);
146                        return;
147                    }
148                };
149
150                if magic.magic != CHANNEL_MAGIC {
151                    warn!(name: "bad_magic", ?peer, magic = magic.magic, expected = CHANNEL_MAGIC);
152                    return;
153                }
154
155                let version_frame = match source.next().await {
156                    Some(Ok(frame)) => frame,
157                    _ => {
158                        warn!(name: "version_failed", ?peer, "no version frame");
159                        return;
160                    }
161                };
162
163                let version: ChannelProtocolVersion = match bincode::deserialize(&version_frame) {
164                    Ok(v) => v,
165                    Err(e) => {
166                        warn!(name: "version_deserialize_failed", ?peer, error = %e);
167                        return;
168                    }
169                };
170
171                if version.version != CHANNEL_PROTOCOL_VERSION {
172                    warn!(name: "version_mismatch", ?peer, version = version.version, expected = CHANNEL_PROTOCOL_VERSION);
173                    return;
174                }
175
176                let handshake_frame = match source.next().await {
177                    Some(Ok(frame)) => frame,
178                    _ => {
179                        warn!(name: "handshake_failed", ?peer, "no handshake frame");
180                        return;
181                    }
182                };
183
184                let handshake: ChannelHandshake = match bincode::deserialize(&handshake_frame) {
185                    Ok(h) => h,
186                    Err(e) => {
187                        warn!(name: "handshake_deserialize_failed", ?peer, error = %e);
188                        return;
189                    }
190                };
191
192                debug!(name: "handshake_received", ?peer, ?handshake);
193
194                let channels = mux.channels.lock().unwrap();
195                if let Some(tx_chan) = channels.get(&handshake.channel_name) {
196                    let _ = tx_chan.send((handshake.sender_id, source));
197                } else {
198                    warn!(
199                        name: "unknown_channel",
200                        channel_name = %handshake.channel_name,
201                        ?peer,
202                        "no registered consumer for channel"
203                    );
204                }
205            });
206        }
207    }
208}
209
210/// Get or initialize the global ChannelMux for this process.
211///
212/// The first call creates the TcpListener and spawns the accept loop.
213/// Subsequent calls return the same `Arc<ChannelMux>`.
214pub fn get_or_init_channel_mux() -> Arc<ChannelMux> {
215    use std::sync::OnceLock;
216    static MUX: OnceLock<Arc<ChannelMux>> = OnceLock::new();
217
218    MUX.get_or_init(|| {
219        let mux = Arc::new(ChannelMux::new());
220        let mux_clone = mux.clone();
221
222        // Spawn the accept loop in a background task.
223        // We use tokio::spawn which requires a runtime to be active.
224        tokio::spawn(async move {
225            let bind_addr = format!("0.0.0.0:{}", CHANNEL_MUX_PORT);
226            debug!(name: "mux_listening", %bind_addr);
227            let listener = TcpListener::bind(&bind_addr)
228                .await
229                .expect("failed to bind channel mux listener");
230            mux_clone.run_accept_loop(listener).await;
231        });
232
233        mux
234    })
235    .clone()
236}
237
238/// Sends a [`ChannelMagic`], then a [`ChannelProtocolVersion`], then a
239/// [`ChannelHandshake`] as three separate frames over the given sink.
240pub async fn send_handshake(
241    sink: &mut FramedWrite<TcpStream, LengthDelimitedCodec>,
242    channel_name: &str,
243    sender_id: Option<&str>,
244) -> Result<(), std::io::Error> {
245    let magic = ChannelMagic {
246        magic: CHANNEL_MAGIC,
247    };
248    sink.send(bytes::Bytes::from(bincode::serialize(&magic).unwrap()))
249        .await?;
250
251    let version = ChannelProtocolVersion {
252        version: CHANNEL_PROTOCOL_VERSION,
253    };
254    sink.send(bytes::Bytes::from(bincode::serialize(&version).unwrap()))
255        .await?;
256
257    let handshake = ChannelHandshake {
258        channel_name: channel_name.to_owned(),
259        sender_id: sender_id.map(|s| s.to_owned()),
260    };
261    sink.send(bytes::Bytes::from(bincode::serialize(&handshake).unwrap()))
262        .await?;
263    Ok(())
264}
265
/// Builds the staged `(sink, source)` expression pair for a
/// process-to-process channel named `channel_name`.
///
/// * The sink expression wraps a closure that connects to `target` on
///   [`CHANNEL_MUX_PORT`], sends the handshake with no sender id, and yields
///   the framed writer.
/// * The source expression wraps a closure that registers `channel_name` with
///   the process-wide [`ChannelMux`] and yields the framed reader of the first
///   connection delivered for that channel.
pub fn deploy_containerized_o2o(target: &str, channel_name: &str) -> (syn::Expr, syn::Expr) {
    (
        q!(LazySink::<_, _, _, bytes::Bytes>::new(move || Box::pin(
            async move {
                let channel_name = channel_name;
                let target = format!("{}:{}", target, self::CHANNEL_MUX_PORT);
                debug!(name: "connecting", %target, %channel_name);

                let stream = TcpStream::connect(&target).await?;
                let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());

                // Process-to-process: there is no member identity to announce.
                self::send_handshake(&mut sink, channel_name, None).await?;

                Result::<_, std::io::Error>::Ok(sink)
            }
        )))
        .splice_untyped_ctx(&()),
        q!(LazySource::new(move || Box::pin(async move {
            let channel_name = channel_name;
            let mux = self::get_or_init_channel_mux();
            let mut rx = mux.register(channel_name.to_owned());

            // Only the first delivered connection is used; `None` means the
            // mux dropped the sender for this channel.
            let (_sender_id, source) = rx.recv().await.ok_or_else(|| {
                std::io::Error::new(std::io::ErrorKind::ConnectionReset, "channel mux closed")
            })?;

            debug!(name: "o2o_channel_connected", %channel_name);

            Result::<_, std::io::Error>::Ok(source)
        })))
        .splice_untyped_ctx(&()),
    )
}
299
/// Builds the staged `(sink, source)` expression pair for a
/// process-to-cluster channel named `channel_name`.
///
/// * The sink expression is a demultiplexer keyed by [`TaglessMemberId`]: for
///   each key it builds a sink that connects to that member's container name
///   on [`CHANNEL_MUX_PORT`] and sends the handshake with no sender id.
/// * The source expression registers `channel_name` with the process-wide
///   [`ChannelMux`] and yields the framed reader of the first connection
///   delivered for that channel.
pub fn deploy_containerized_o2m(channel_name: &str) -> (syn::Expr, syn::Expr) {
    (
        q!(sinktools::demux_map_lazy::<_, _, _, _>(
            move |key: &TaglessMemberId| {
                // Own copies so the per-member connect closure can outlive this call.
                let key = key.clone();
                let channel_name = channel_name.to_owned();

                LazySink::<_, _, _, bytes::Bytes>::new(move || {
                    Box::pin(async move {
                        let target =
                            format!("{}:{}", key.get_container_name(), self::CHANNEL_MUX_PORT);
                        debug!(name: "connecting", %target, channel_name = %channel_name);

                        let stream = TcpStream::connect(&target).await?;
                        let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());

                        // The sender is a single process, so no member id is announced.
                        self::send_handshake(&mut sink, &channel_name, None).await?;

                        Result::<_, std::io::Error>::Ok(sink)
                    })
                })
            }
        ))
        .splice_untyped_ctx(&()),
        q!(LazySource::new(move || Box::pin(async move {
            let channel_name = channel_name;
            let mux = self::get_or_init_channel_mux();
            let mut rx = mux.register(channel_name.to_owned());

            // Only the first delivered connection is used; `None` means the
            // mux dropped the sender for this channel.
            let (_sender_id, source) = rx.recv().await.ok_or_else(|| {
                std::io::Error::new(std::io::ErrorKind::ConnectionReset, "channel mux closed")
            })?;

            debug!(name: "o2m_channel_connected", %channel_name);

            Result::<_, std::io::Error>::Ok(source)
        })))
        .splice_untyped_ctx(&()),
    )
}
340
341pub fn deploy_containerized_m2o(target_host: &str, channel_name: &str) -> (syn::Expr, syn::Expr) {
342    (
343        q!(LazySink::<_, _, _, bytes::Bytes>::new(move || {
344            Box::pin(async move {
345                let channel_name = channel_name;
346                let target = format!("{}:{}", target_host, self::CHANNEL_MUX_PORT);
347                debug!(name: "connecting", %target, %channel_name);
348
349                let stream = TcpStream::connect(&target).await?;
350                let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());
351
352                let container_name = std::env::var("CONTAINER_NAME").unwrap();
353                self::send_handshake(&mut sink, channel_name, Some(&container_name)).await?;
354
355                Result::<_, std::io::Error>::Ok(sink)
356            })
357        }))
358        .splice_untyped_ctx(&()),
359        q!(LazySource::new(move || Box::pin(async move {
360            let channel_name = channel_name;
361            let mux = self::get_or_init_channel_mux();
362            let mut rx = mux.register(channel_name.to_owned());
363
364            Result::<_, std::io::Error>::Ok(
365                futures::stream::unfold(rx, |mut rx| {
366                    Box::pin(async move {
367                        let (sender_id, source) = rx.recv().await?;
368                        let from = sender_id.expect("m2o sender must provide container name");
369
370                        debug!(name: "m2o_channel_connected", %from);
371
372                        Some((
373                            source.map(move |v| {
374                                v.map(|v| (TaglessMemberId::from_container_name(from.clone()), v))
375                            }),
376                            rx,
377                        ))
378                    })
379                })
380                .flatten_unordered(None),
381            )
382        })))
383        .splice_untyped_ctx(&()),
384    )
385}
386
387pub fn deploy_containerized_m2m(channel_name: &str) -> (syn::Expr, syn::Expr) {
388    (
389        q!(sinktools::demux_map_lazy::<_, _, _, _>(
390            move |key: &TaglessMemberId| {
391                let key = key.clone();
392                let channel_name = channel_name.to_owned();
393
394                LazySink::<_, _, _, bytes::Bytes>::new(move || {
395                    Box::pin(async move {
396                        let target =
397                            format!("{}:{}", key.get_container_name(), self::CHANNEL_MUX_PORT);
398                        debug!(name: "connecting", %target, channel_name = %channel_name);
399
400                        let stream = TcpStream::connect(&target).await?;
401                        let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());
402
403                        let container_name = std::env::var("CONTAINER_NAME").unwrap();
404                        self::send_handshake(&mut sink, &channel_name, Some(&container_name))
405                            .await?;
406
407                        Result::<_, std::io::Error>::Ok(sink)
408                    })
409                })
410            }
411        ))
412        .splice_untyped_ctx(&()),
413        q!(LazySource::new(move || Box::pin(async move {
414            let channel_name = channel_name;
415            let mux = self::get_or_init_channel_mux();
416            let mut rx = mux.register(channel_name.to_owned());
417
418            Result::<_, std::io::Error>::Ok(
419                futures::stream::unfold(rx, |mut rx| {
420                    Box::pin(async move {
421                        let (sender_id, source) = rx.recv().await?;
422                        let from = sender_id.expect("m2m sender must provide container name");
423
424                        debug!(name: "m2m_channel_connected", %from);
425
426                        Some((
427                            source.map(move |v| {
428                                v.map(|v| (TaglessMemberId::from_container_name(from.clone()), v))
429                            }),
430                            rx,
431                        ))
432                    })
433                })
434                .flatten_unordered(None),
435            )
436        })))
437        .splice_untyped_ctx(&()),
438    )
439}
440
/// Wrapper around the identifier of a listener variable so it can be spliced
/// into quoted code as a free variable of type [`TcpListener`].
pub struct SocketIdent {
    /// Identifier that is emitted verbatim as the spliced expression.
    pub socket_ident: syn::Ident,
}
444
445impl<Ctx> FreeVariableWithContextWithProps<Ctx, ()> for SocketIdent {
446    type O = TcpListener;
447
448    fn to_tokens(self, _ctx: &Ctx) -> (QuoteTokens, ())
449    where
450        Self: Sized,
451    {
452        let ident = self.socket_ident;
453
454        (
455            QuoteTokens {
456                prelude: None,
457                expr: Some(quote::quote! { #ident }),
458            },
459            (),
460        )
461    }
462}
463
/// Builds the staged expression for an externally facing bidirectional
/// channel backed by the listener variable named `socket_ident`.
///
/// The generated code awaits a single `accept()` on that listener, splits the
/// accepted stream, and wraps the read and write halves in length-delimited
/// framing.
pub fn deploy_containerized_external_sink_source_ident(socket_ident: syn::Ident) -> syn::Expr {
    // Wrap the identifier so `q!` splices it as a `TcpListener` free variable.
    let socket_ident = SocketIdent { socket_ident };

    q!(LazySinkSource::<
        _,
        FramedRead<OwnedReadHalf, LengthDelimitedCodec>,
        FramedWrite<OwnedWriteHalf, LengthDelimitedCodec>,
        bytes::Bytes,
        std::io::Error,
    >::new(async move {
        let (stream, peer) = socket_ident.accept().await?;
        debug!(name: "external accepting", ?peer);
        let (rx, tx) = stream.into_split();

        let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
        let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());

        Result::<_, std::io::Error>::Ok((fr, fw))
    },))
    .splice_untyped_ctx(&())
}
485
/// Returns a placeholder quoted member-id slice.
///
/// Clusters are dynamic when containerized, so there is no static member
/// list. The quoted expression leaks a one-element array holding a
/// deliberately invalid container name so that any accidental use is easy to
/// spot.
pub fn cluster_ids<'a>() -> impl QuotedWithContext<'a, &'a [TaglessMemberId], ()> + Clone {
    // This is a dummy piece of code, since clusters are dynamic when containerized.
    q!(Box::leak(Box::new([TaglessMemberId::from_container_name(
        "INVALID CONTAINER NAME cluster_ids"
    )]))
    .as_slice())
}
495
/// Returns a quoted expression that evaluates to this member's id, built from
/// the `CONTAINER_NAME` environment variable.
///
/// The generated code panics if `CONTAINER_NAME` is unset or not valid
/// Unicode.
#[cfg(feature = "docker_runtime")]
pub fn cluster_self_id<'a>() -> impl QuotedWithContext<'a, TaglessMemberId, ()> + Clone + 'a {
    q!(TaglessMemberId::from_container_name(
        std::env::var("CONTAINER_NAME").unwrap()
    ))
}
502
/// Returns a quoted expression that evaluates to the boxed membership event
/// stream for the cluster identified by `location_id`.
///
/// The generated code reads the `DEPLOYMENT_INSTANCE` environment variable
/// (panicking if it is unavailable) and passes it, together with the
/// location's key, to `docker_membership_stream`.
#[cfg(feature = "docker_runtime")]
pub fn cluster_membership_stream<'a>(
    location_id: &LocationId,
) -> impl QuotedWithContext<'a, Box<dyn Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin>, ()>
{
    // Extract the key up front so the quoted code captures a plain value.
    let key = location_id.key();

    q!(Box::new(self::docker_membership_stream(
        std::env::var("DEPLOYMENT_INSTANCE").unwrap(),
        key
    ))
        as Box<
            dyn Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin,
        >)
}
518
/// Streams membership changes for the containers whose name contains
/// `"{deployment_instance}-{location_key}"`.
///
/// There's a risk of race conditions here since all the containers will be
/// starting up at the same time. So we start listening for events and then
/// take a snapshot of currently running containers, since they may have
/// already started up before we started listening to events. The result is
/// emitted as: every container from the snapshot as `Joined`, followed by the
/// live `start`/`die` events. A shared hash set tracks which containers are
/// currently considered up, so a container is never reported as joined twice
/// and a `Left` is only reported for a container previously reported as
/// joined.
///
/// Docker failures are logged and otherwise tolerated: a failed snapshot is
/// treated as empty and an errored event is skipped.
///
/// # Panics
///
/// Panics if the local Docker client cannot be created.
#[cfg(feature = "docker_runtime")]
#[instrument(skip_all, fields(%deployment_instance, %location_key))]
fn docker_membership_stream(
    deployment_instance: String,
    location_key: LocationKey,
) -> impl Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin {
    use std::collections::HashSet;
    use std::sync::{Arc, Mutex};

    use bollard::Docker;
    use bollard::query_parameters::{EventsOptions, ListContainersOptions};
    use tokio::sync::mpsc;

    let docker = Docker::connect_with_local_defaults()
        .unwrap()
        .with_timeout(Duration::from_secs(1));

    let (event_tx, event_rx) = mpsc::unbounded_channel::<(String, MembershipEvent)>();

    // 1. Start event subscription in a spawned task
    let events_docker = docker.clone();
    let events_deployment_instance = deployment_instance.clone();
    tokio::spawn(async move {
        // Computed once rather than once per event.
        let name_fragment = format!("{events_deployment_instance}-{location_key}");

        let mut filters = HashMap::new();
        filters.insert("type".to_owned(), vec!["container".to_owned()]);
        filters.insert(
            "event".to_owned(),
            vec!["start".to_owned(), "die".to_owned()],
        );
        let event_options = Some(EventsOptions {
            filters: Some(filters),
            ..Default::default()
        });

        let mut events = events_docker.events(event_options);
        while let Some(event) = events.next().await {
            let event = match event {
                Ok(event) => event,
                Err(e) => {
                    // Keep going, but leave a trace: a dropped event can mean
                    // a missed join or leave.
                    warn!(name: "docker_event_error", error = %e);
                    continue;
                }
            };

            let Some(name) = event
                .actor
                .as_ref()
                .and_then(|a| a.attributes.as_ref())
                .and_then(|attrs| attrs.get("name"))
            else {
                continue;
            };

            // Ignore containers that belong to other deployments or locations.
            if !name.contains(name_fragment.as_str()) {
                continue;
            }

            let membership_event = match event.action.as_deref() {
                Some("start") => MembershipEvent::Joined,
                Some("die") => MembershipEvent::Left,
                _ => continue,
            };

            // The receiver is gone once the consumer drops the stream; stop listening.
            if event_tx.send((name.to_owned(), membership_event)).is_err() {
                break;
            }
        }
    });

    // Shared state for deduplication across snapshot and events phases
    let seen_joined = Arc::new(Mutex::new(HashSet::<String>::new()));
    let seen_joined_snapshot = seen_joined.clone();
    let seen_joined_events = seen_joined;

    // 2. Snapshot stream - fetch current containers and emit Joined events
    let snapshot_stream = futures::stream::once(async move {
        let mut filters = HashMap::new();
        filters.insert(
            "name".to_owned(),
            vec![format!("{deployment_instance}-{location_key}")],
        );
        let options = Some(ListContainersOptions {
            filters: Some(filters),
            ..Default::default()
        });

        let containers = match docker.list_containers(options).await {
            Ok(containers) => containers,
            Err(e) => {
                // Fall back to an empty snapshot (members can still arrive via
                // live events), but do not hide the failure.
                warn!(name: "docker_snapshot_failed", error = %e);
                Vec::new()
            }
        };

        containers
            .iter()
            .filter_map(|c| c.names.as_deref())
            .filter_map(|names| names.first())
            // Docker reports names with a leading '/'; events do not.
            .map(|name| name.trim_start_matches('/'))
            .filter(|&name| seen_joined_snapshot.lock().unwrap().insert(name.to_owned()))
            .map(|name| (name.to_owned(), MembershipEvent::Joined))
            .collect::<Vec<_>>()
    })
    .flat_map(futures::stream::iter);

    // 3. Events stream - process live events with deduplication
    let events_stream = tokio_stream::StreamExt::filter_map(
        tokio_stream::wrappers::UnboundedReceiverStream::new(event_rx),
        move |(name, event)| {
            let mut seen = seen_joined_events.lock().unwrap();
            match event {
                MembershipEvent::Joined => {
                    // Suppress joins already reported by the snapshot or an earlier event.
                    if seen.insert(name.clone()) {
                        Some((name, MembershipEvent::Joined))
                    } else {
                        None
                    }
                }
                // Only report a leave for a member that was reported as joined.
                MembershipEvent::Left => seen.take(&name).map(|name| (name, MembershipEvent::Left)),
            }
        },
    );

    // 4. Chain snapshot then events
    Box::pin(
        snapshot_stream
            .chain(events_stream)
            .map(|(k, v)| (TaglessMemberId::from_container_name(k), v))
            .inspect(|(member_id, event)| debug!(name: "membership_event", ?member_id, ?event)),
    )
}