aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorUlf Lilleengen <[email protected]>2022-08-09 14:43:55 +0200
committerUlf Lilleengen <[email protected]>2022-08-09 14:43:55 +0200
commit80c1551153b06a14e5dc475f6fbd945db06b8117 (patch)
treeeafa63d946c9625ceebadb99bd864c160e040336
parent18671b94ba173d6b5c2d2ec5e3569e39a03b61bb (diff)
Wrap buffers in a single state type
-rw-r--r--embassy-net/src/tcp.rs53
-rw-r--r--examples/stm32h7/Cargo.toml1
2 files changed, 31 insertions, 23 deletions
diff --git a/embassy-net/src/tcp.rs b/embassy-net/src/tcp.rs
index 96a6dfe28..814e7ab63 100644
--- a/embassy-net/src/tcp.rs
+++ b/embassy-net/src/tcp.rs
@@ -339,15 +339,16 @@ pub mod client {
339 339
340 use super::*; 340 use super::*;
341 341
342 /// TCP client capable of creating up to N multiple connections with tx and rx buffers according to TX_SZ and RX_SZ.
342 pub struct TcpClient<'d, D: Device, const N: usize, const TX_SZ: usize = 1024, const RX_SZ: usize = 1024> { 343 pub struct TcpClient<'d, D: Device, const N: usize, const TX_SZ: usize = 1024, const RX_SZ: usize = 1024> {
343 stack: &'d Stack<D>, 344 stack: &'d Stack<D>,
344 tx: &'d BufferPool<TX_SZ, N>, 345 state: &'d TcpClientState<N, TX_SZ, RX_SZ>,
345 rx: &'d BufferPool<RX_SZ, N>,
346 } 346 }
347 347
348 impl<'d, D: Device, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpClient<'d, D, N, TX_SZ, RX_SZ> { 348 impl<'d, D: Device, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpClient<'d, D, N, TX_SZ, RX_SZ> {
349 pub fn new(stack: &'d Stack<D>, tx: &'d BufferPool<TX_SZ, N>, rx: &'d BufferPool<RX_SZ, N>) -> Self { 349 /// Create a new TcpClient
350 Self { stack, tx, rx } 350 pub fn new(stack: &'d Stack<D>, state: &'d TcpClientState<N, TX_SZ, RX_SZ>) -> Self {
351 Self { stack, state }
351 } 352 }
352 } 353 }
353 354
@@ -370,7 +371,7 @@ pub mod client {
370 IpAddr::V6(_) => panic!("ipv6 support not enabled"), 371 IpAddr::V6(_) => panic!("ipv6 support not enabled"),
371 }; 372 };
372 let remote_endpoint = (addr, remote.port()); 373 let remote_endpoint = (addr, remote.port());
373 let mut socket = TcpConnection::new(&self.stack, self.tx, self.rx)?; 374 let mut socket = TcpConnection::new(&self.stack, self.state)?;
374 socket 375 socket
375 .socket 376 .socket
376 .connect(remote_endpoint) 377 .connect(remote_endpoint)
@@ -383,26 +384,20 @@ pub mod client {
383 384
384 pub struct TcpConnection<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> { 385 pub struct TcpConnection<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> {
385 socket: TcpSocket<'d>, 386 socket: TcpSocket<'d>,
386 tx: &'d BufferPool<TX_SZ, N>, 387 state: &'d TcpClientState<N, TX_SZ, RX_SZ>,
387 rx: &'d BufferPool<RX_SZ, N>, 388 bufs: NonNull<([u8; TX_SZ], [u8; RX_SZ])>,
388 txb: NonNull<[u8; TX_SZ]>,
389 rxb: NonNull<[u8; RX_SZ]>,
390 } 389 }
391 390
392 impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpConnection<'d, N, TX_SZ, RX_SZ> { 391 impl<'d, const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpConnection<'d, N, TX_SZ, RX_SZ> {
393 fn new<D: Device>( 392 fn new<D: Device>(
394 stack: &'d Stack<D>, 393 stack: &'d Stack<D>,
395 tx: &'d BufferPool<TX_SZ, N>, 394 state: &'d TcpClientState<N, TX_SZ, RX_SZ>,
396 rx: &'d BufferPool<RX_SZ, N>,
397 ) -> Result<Self, Error> { 395 ) -> Result<Self, Error> {
398 let mut txb = tx.alloc().ok_or(Error::ConnectionReset)?; 396 let mut bufs = state.pool.alloc().ok_or(Error::ConnectionReset)?;
399 let mut rxb = rx.alloc().ok_or(Error::ConnectionReset)?;
400 Ok(Self { 397 Ok(Self {
401 socket: unsafe { TcpSocket::new(stack, rxb.as_mut(), txb.as_mut()) }, 398 socket: unsafe { TcpSocket::new(stack, &mut bufs.as_mut().0, &mut bufs.as_mut().1) },
402 tx, 399 state,
403 rx, 400 bufs,
404 txb,
405 rxb,
406 }) 401 })
407 } 402 }
408 } 403 }
@@ -411,8 +406,7 @@ pub mod client {
411 fn drop(&mut self) { 406 fn drop(&mut self) {
412 unsafe { 407 unsafe {
413 self.socket.close(); 408 self.socket.close();
414 self.rx.free(self.rxb); 409 self.state.pool.free(self.bufs);
415 self.tx.free(self.txb);
416 } 410 }
417 } 411 }
418 } 412 }
@@ -455,9 +449,22 @@ pub mod client {
455 } 449 }
456 } 450 }
457 451
458 pub type BufferPool<const BUFSZ: usize, const N: usize> = Pool<[u8; BUFSZ], N>; 452 /// State for TcpClient
453 pub struct TcpClientState<const N: usize, const TX_SZ: usize, const RX_SZ: usize> {
454 pool: Pool<([u8; TX_SZ], [u8; RX_SZ]), N>,
455 }
459 456
460 pub struct Pool<T, const N: usize> { 457 impl<const N: usize, const TX_SZ: usize, const RX_SZ: usize> TcpClientState<N, TX_SZ, RX_SZ> {
458 pub const fn new() -> Self {
459 Self {
460 pool: Pool::new()
461 }
462 }
463 }
464
465 unsafe impl<const N: usize, const TX_SZ: usize, const RX_SZ: usize> Sync for TcpClientState<N, TX_SZ, RX_SZ> {}
466
467 struct Pool<T, const N: usize> {
461 used: [AtomicBool; N], 468 used: [AtomicBool; N],
462 data: [UnsafeCell<MaybeUninit<T>>; N], 469 data: [UnsafeCell<MaybeUninit<T>>; N],
463 } 470 }
@@ -466,7 +473,7 @@ pub mod client {
466 const VALUE: AtomicBool = AtomicBool::new(false); 473 const VALUE: AtomicBool = AtomicBool::new(false);
467 const UNINIT: UnsafeCell<MaybeUninit<T>> = UnsafeCell::new(MaybeUninit::uninit()); 474 const UNINIT: UnsafeCell<MaybeUninit<T>> = UnsafeCell::new(MaybeUninit::uninit());
468 475
469 pub const fn new() -> Self { 476 const fn new() -> Self {
470 Self { 477 Self {
471 used: [Self::VALUE; N], 478 used: [Self::VALUE; N],
472 data: [Self::UNINIT; N], 479 data: [Self::UNINIT; N],
diff --git a/examples/stm32h7/Cargo.toml b/examples/stm32h7/Cargo.toml
index 07b7e4931..896046759 100644
--- a/examples/stm32h7/Cargo.toml
+++ b/examples/stm32h7/Cargo.toml
@@ -18,6 +18,7 @@ cortex-m-rt = "0.7.0"
18embedded-hal = "0.2.6" 18embedded-hal = "0.2.6"
19embedded-hal-1 = { package = "embedded-hal", version = "1.0.0-alpha.8" } 19embedded-hal-1 = { package = "embedded-hal", version = "1.0.0-alpha.8" }
20embedded-hal-async = { version = "0.1.0-alpha.1" } 20embedded-hal-async = { version = "0.1.0-alpha.1" }
21embedded-nal-async = "0.2.0"
21panic-probe = { version = "0.3", features = ["print-defmt"] } 22panic-probe = { version = "0.3", features = ["print-defmt"] }
22futures = { version = "0.3.17", default-features = false, features = ["async-await"] } 23futures = { version = "0.3.17", default-features = false, features = ["async-await"] }
23heapless = { version = "0.7.5", default-features = false } 24heapless = { version = "0.7.5", default-features = false }