aboutsummaryrefslogtreecommitdiff
path: root/embassy-net/src/dns.rs
blob: dbe73776c7e23543a990b7e849ef39f653789e45 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
//! DNS client compatible with the `embedded-nal-async` traits.
//!
//! This exists only for compatibility with crates that use `embedded-nal-async`.
//! Prefer using [`Stack::dns_query`](crate::Stack::dns_query) directly if you're
//! not using `embedded-nal-async`.

use heapless::Vec;
pub use smoltcp::socket::dns::{DnsQuery, Socket};
pub(crate) use smoltcp::socket::dns::{GetQueryResultError, StartQueryError};
pub use smoltcp::wire::{DnsQueryType, IpAddress};

use crate::Stack;

/// Failure modes reported by [`DnsSocket`] and its `embedded-nal-async` implementation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum Error {
    /// The name passed to the query is not a valid DNS name.
    InvalidName,
    /// The name passed to the query is too long.
    NameTooLong,
    /// The lookup did not produce a result (this also covers running out of query slots).
    Failed,
}

impl From<GetQueryResultError> for Error {
    fn from(_: GetQueryResultError) -> Self {
        Self::Failed
    }
}

impl From<StartQueryError> for Error {
    fn from(e: StartQueryError) -> Self {
        match e {
            StartQueryError::NoFreeSlot => Self::Failed,
            StartQueryError::InvalidName => Self::InvalidName,
            StartQueryError::NameTooLong => Self::NameTooLong,
        }
    }
}

/// Thin DNS client over an embassy-net [`Stack`], compatible with the `embedded-nal-async` traits.
///
/// It is only provided as an adapter for crates written against `embedded-nal-async`. If you
/// do not need that trait, call [`Stack::dns_query`](crate::Stack::dns_query) directly instead.
pub struct DnsSocket<'a> {
    stack: Stack<'a>,
}

impl<'a> DnsSocket<'a> {
    /// Wrap `stack` in a DNS client.
    ///
    /// NOTE: when the stack is configured through DHCP, make sure DHCP has (re)configured it
    /// before querying, so that the stack's DNS servers are up to date.
    pub fn new(stack: Stack<'a>) -> Self {
        DnsSocket { stack }
    }

    /// Resolve `name` using record type `qtype` and return all IP addresses found.
    ///
    /// This forwards directly to [`Stack::dns_query`](crate::Stack::dns_query).
    pub async fn query(
        &self,
        name: &str,
        qtype: DnsQueryType,
    ) -> Result<Vec<IpAddress, { smoltcp::config::DNS_MAX_RESULT_COUNT }>, Error> {
        self.stack.dns_query(name, qtype).await
    }
}

impl<'a> embedded_nal_async::Dns for DnsSocket<'a> {
    type Error = Error;

    async fn get_host_by_name(
        &self,
        host: &str,
        addr_type: embedded_nal_async::AddrType,
    ) -> Result<core::net::IpAddr, Self::Error> {
        use core::net::IpAddr;

        use embedded_nal_async::AddrType;

        let (qtype, secondary_qtype) = match addr_type {
            AddrType::IPv4 => (DnsQueryType::A, None),
            AddrType::IPv6 => (DnsQueryType::Aaaa, None),
            AddrType::Either => {
                #[cfg(not(feature = "proto-ipv6"))]
                let v6_first = false;
                #[cfg(feature = "proto-ipv6")]
                let v6_first = self.stack.config_v6().is_some();
                match v6_first {
                    true => (DnsQueryType::Aaaa, Some(DnsQueryType::A)),
                    false => (DnsQueryType::A, Some(DnsQueryType::Aaaa)),
                }
            }
        };
        let mut addrs = self.query(host, qtype).await?;
        if addrs.is_empty() {
            if let Some(qtype) = secondary_qtype {
                addrs = self.query(host, qtype).await?
            }
        }
        if let Some(first) = addrs.get(0) {
            Ok(match first {
                #[cfg(feature = "proto-ipv4")]
                IpAddress::Ipv4(addr) => IpAddr::V4(*addr),
                #[cfg(feature = "proto-ipv6")]
                IpAddress::Ipv6(addr) => IpAddr::V6(*addr),
            })
        } else {
            Err(Error::Failed)
        }
    }

    async fn get_host_by_address(&self, _addr: core::net::IpAddr, _result: &mut [u8]) -> Result<usize, Self::Error> {
        todo!()
    }
}

// Compile-time assertion that `DnsSocket` is covariant in its lifetime parameter.
// Returning `x` unchanged only type-checks if a `DnsSocket<'b>` may be used where a
// `DnsSocket<'a>` with a shorter lifetime (`'b: 'a`) is expected. The function is never called.
fn _assert_covariant<'a, 'b: 'a>(x: DnsSocket<'b>) -> DnsSocket<'a> {
    x
}