aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--embassy-nrf/src/usb.rs20
-rw-r--r--embassy-usb-hid/src/async_lease.rs90
-rw-r--r--embassy-usb-hid/src/lib.rs143
-rw-r--r--embassy-usb/src/driver.rs8
-rw-r--r--examples/nrf/src/bin/usb_hid.rs4
5 files changed, 78 insertions, 187 deletions
diff --git a/embassy-nrf/src/usb.rs b/embassy-nrf/src/usb.rs
index df0efa511..124316a29 100644
--- a/embassy-nrf/src/usb.rs
+++ b/embassy-nrf/src/usb.rs
@@ -443,27 +443,12 @@ unsafe fn write_dma<T: Instance>(i: usize, buf: &[u8]) {
443 443
444impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> { 444impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> {
445 type ReadFuture<'a> = impl Future<Output = Result<usize, ReadError>> + 'a where Self: 'a; 445 type ReadFuture<'a> = impl Future<Output = Result<usize, ReadError>> + 'a where Self: 'a;
446 type DataReadyFuture<'a> = impl Future<Output = ()> + 'a where Self: 'a;
447 446
448 fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { 447 fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> {
449 async move { 448 async move {
450 let i = self.info.addr.index(); 449 let i = self.info.addr.index();
451 assert!(i != 0); 450 assert!(i != 0);
452 451
453 self.wait_data_ready().await;
454
455 // Mark as not ready
456 READY_ENDPOINTS.fetch_and(!(1 << (i + 16)), Ordering::AcqRel);
457
458 unsafe { read_dma::<T>(i, buf) }
459 }
460 }
461
462 fn wait_data_ready<'a>(&'a mut self) -> Self::DataReadyFuture<'a> {
463 async move {
464 let i = self.info.addr.index();
465 assert!(i != 0);
466
467 // Wait until ready 452 // Wait until ready
468 poll_fn(|cx| { 453 poll_fn(|cx| {
469 EP_OUT_WAKERS[i - 1].register(cx.waker()); 454 EP_OUT_WAKERS[i - 1].register(cx.waker());
@@ -475,6 +460,11 @@ impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> {
475 } 460 }
476 }) 461 })
477 .await; 462 .await;
463
464 // Mark as not ready
465 READY_ENDPOINTS.fetch_and(!(1 << (i + 16)), Ordering::AcqRel);
466
467 unsafe { read_dma::<T>(i, buf) }
478 } 468 }
479 } 469 }
480} 470}
diff --git a/embassy-usb-hid/src/async_lease.rs b/embassy-usb-hid/src/async_lease.rs
deleted file mode 100644
index 0971daa25..000000000
--- a/embassy-usb-hid/src/async_lease.rs
+++ /dev/null
@@ -1,90 +0,0 @@
1use core::cell::Cell;
2use core::future::Future;
3use core::task::{Poll, Waker};
4
5enum AsyncLeaseState {
6 Empty,
7 Waiting(*mut u8, usize, Waker),
8 Done(usize),
9}
10
11impl Default for AsyncLeaseState {
12 fn default() -> Self {
13 AsyncLeaseState::Empty
14 }
15}
16
17#[derive(Default)]
18pub struct AsyncLease {
19 state: Cell<AsyncLeaseState>,
20}
21
22pub struct AsyncLeaseFuture<'a> {
23 buf: &'a mut [u8],
24 state: &'a Cell<AsyncLeaseState>,
25}
26
27impl<'a> Drop for AsyncLeaseFuture<'a> {
28 fn drop(&mut self) {
29 self.state.set(AsyncLeaseState::Empty);
30 }
31}
32
33impl<'a> Future for AsyncLeaseFuture<'a> {
34 type Output = usize;
35
36 fn poll(
37 mut self: core::pin::Pin<&mut Self>,
38 cx: &mut core::task::Context<'_>,
39 ) -> Poll<Self::Output> {
40 match self.state.take() {
41 AsyncLeaseState::Done(len) => Poll::Ready(len),
42 state => {
43 if let AsyncLeaseState::Waiting(ptr, _, _) = state {
44 assert_eq!(
45 ptr,
46 self.buf.as_mut_ptr(),
47 "lend() called on a busy AsyncLease."
48 );
49 }
50
51 self.state.set(AsyncLeaseState::Waiting(
52 self.buf.as_mut_ptr(),
53 self.buf.len(),
54 cx.waker().clone(),
55 ));
56 Poll::Pending
57 }
58 }
59 }
60}
61
62pub struct AsyncLeaseNotReady {}
63
64impl AsyncLease {
65 pub fn new() -> Self {
66 Default::default()
67 }
68
69 pub fn try_borrow_mut<F: FnOnce(&mut [u8]) -> usize>(
70 &self,
71 f: F,
72 ) -> Result<(), AsyncLeaseNotReady> {
73 if let AsyncLeaseState::Waiting(data, len, waker) = self.state.take() {
74 let buf = unsafe { core::slice::from_raw_parts_mut(data, len) };
75 let len = f(buf);
76 self.state.set(AsyncLeaseState::Done(len));
77 waker.wake();
78 Ok(())
79 } else {
80 Err(AsyncLeaseNotReady {})
81 }
82 }
83
84 pub fn lend<'a>(&'a self, buf: &'a mut [u8]) -> AsyncLeaseFuture<'a> {
85 AsyncLeaseFuture {
86 buf,
87 state: &self.state,
88 }
89 }
90}
diff --git a/embassy-usb-hid/src/lib.rs b/embassy-usb-hid/src/lib.rs
index 43e678806..527f014f2 100644
--- a/embassy-usb-hid/src/lib.rs
+++ b/embassy-usb-hid/src/lib.rs
@@ -8,24 +8,21 @@
8pub(crate) mod fmt; 8pub(crate) mod fmt;
9 9
10use core::mem::MaybeUninit; 10use core::mem::MaybeUninit;
11use core::ops::Range;
11 12
12use async_lease::AsyncLease;
13use embassy::time::Duration; 13use embassy::time::Duration;
14use embassy_usb::driver::{EndpointOut, ReadError}; 14use embassy_usb::driver::EndpointOut;
15use embassy_usb::{ 15use embassy_usb::{
16 control::{ControlHandler, InResponse, OutResponse, Request, RequestType}, 16 control::{ControlHandler, InResponse, OutResponse, Request, RequestType},
17 driver::{Driver, Endpoint, EndpointIn, WriteError}, 17 driver::{Driver, Endpoint, EndpointIn, WriteError},
18 UsbDeviceBuilder, 18 UsbDeviceBuilder,
19}; 19};
20use futures_util::future::{select, Either}; 20
21use futures_util::pin_mut;
22#[cfg(feature = "usbd-hid")] 21#[cfg(feature = "usbd-hid")]
23use ssmarshal::serialize; 22use ssmarshal::serialize;
24#[cfg(feature = "usbd-hid")] 23#[cfg(feature = "usbd-hid")]
25use usbd_hid::descriptor::AsInputReport; 24use usbd_hid::descriptor::AsInputReport;
26 25
27mod async_lease;
28
29const USB_CLASS_HID: u8 = 0x03; 26const USB_CLASS_HID: u8 = 0x03;
30const USB_SUBCLASS_NONE: u8 = 0x00; 27const USB_SUBCLASS_NONE: u8 = 0x00;
31const USB_PROTOCOL_NONE: u8 = 0x00; 28const USB_PROTOCOL_NONE: u8 = 0x00;
@@ -64,14 +61,12 @@ impl ReportId {
64 61
65pub struct State<'a, const IN_N: usize, const OUT_N: usize> { 62pub struct State<'a, const IN_N: usize, const OUT_N: usize> {
66 control: MaybeUninit<Control<'a>>, 63 control: MaybeUninit<Control<'a>>,
67 lease: AsyncLease,
68} 64}
69 65
70impl<'a, const IN_N: usize, const OUT_N: usize> State<'a, IN_N, OUT_N> { 66impl<'a, const IN_N: usize, const OUT_N: usize> State<'a, IN_N, OUT_N> {
71 pub fn new() -> Self { 67 pub fn new() -> Self {
72 State { 68 State {
73 control: MaybeUninit::uninit(), 69 control: MaybeUninit::uninit(),
74 lease: AsyncLease::new(),
75 } 70 }
76 } 71 }
77} 72}
@@ -90,9 +85,9 @@ impl<'d, D: Driver<'d>, const IN_N: usize> HidClass<'d, D, (), IN_N> {
90 /// high performance uses, and a value of 255 is good for best-effort usecases. 85 /// high performance uses, and a value of 255 is good for best-effort usecases.
91 /// 86 ///
92 /// This allocates an IN endpoint only. 87 /// This allocates an IN endpoint only.
93 pub fn new( 88 pub fn new<const OUT_N: usize>(
94 builder: &mut UsbDeviceBuilder<'d, D>, 89 builder: &mut UsbDeviceBuilder<'d, D>,
95 state: &'d mut State<'d, IN_N, 0>, 90 state: &'d mut State<'d, IN_N, OUT_N>,
96 report_descriptor: &'static [u8], 91 report_descriptor: &'static [u8],
97 request_handler: Option<&'d dyn RequestHandler>, 92 request_handler: Option<&'d dyn RequestHandler>,
98 poll_ms: u8, 93 poll_ms: u8,
@@ -101,8 +96,7 @@ impl<'d, D: Driver<'d>, const IN_N: usize> HidClass<'d, D, (), IN_N> {
101 let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms); 96 let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms);
102 let control = state 97 let control = state
103 .control 98 .control
104 .write(Control::new(report_descriptor, None, request_handler)); 99 .write(Control::new(report_descriptor, request_handler));
105
106 control.build(builder, None, &ep_in); 100 control.build(builder, None, &ep_in);
107 101
108 Self { 102 Self {
@@ -144,20 +138,14 @@ impl<'d, D: Driver<'d>, const IN_N: usize, const OUT_N: usize>
144 let ep_out = builder.alloc_interrupt_endpoint_out(max_packet_size, poll_ms); 138 let ep_out = builder.alloc_interrupt_endpoint_out(max_packet_size, poll_ms);
145 let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms); 139 let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms);
146 140
147 let control = state.control.write(Control::new( 141 let control = state
148 report_descriptor, 142 .control
149 Some(&state.lease), 143 .write(Control::new(report_descriptor, request_handler));
150 request_handler,
151 ));
152
153 control.build(builder, Some(&ep_out), &ep_in); 144 control.build(builder, Some(&ep_out), &ep_in);
154 145
155 Self { 146 Self {
156 input: ReportWriter { ep_in }, 147 input: ReportWriter { ep_in },
157 output: ReportReader { 148 output: ReportReader { ep_out, offset: 0 },
158 ep_out,
159 lease: &state.lease,
160 },
161 } 149 }
162 } 150 }
163 151
@@ -178,7 +166,21 @@ pub struct ReportWriter<'d, D: Driver<'d>, const N: usize> {
178 166
179pub struct ReportReader<'d, D: Driver<'d>, const N: usize> { 167pub struct ReportReader<'d, D: Driver<'d>, const N: usize> {
180 ep_out: D::EndpointOut, 168 ep_out: D::EndpointOut,
181 lease: &'d AsyncLease, 169 offset: usize,
170}
171
172pub enum ReadError {
173 BufferOverflow,
174 Sync(Range<usize>),
175}
176
177impl From<embassy_usb::driver::ReadError> for ReadError {
178 fn from(val: embassy_usb::driver::ReadError) -> Self {
179 use embassy_usb::driver::ReadError::*;
180 match val {
181 BufferOverflow => ReadError::BufferOverflow,
182 }
183 }
182} 184}
183 185
184impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> { 186impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> {
@@ -216,31 +218,55 @@ impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> {
216} 218}
217 219
218impl<'d, D: Driver<'d>, const N: usize> ReportReader<'d, D, N> { 220impl<'d, D: Driver<'d>, const N: usize> ReportReader<'d, D, N> {
219 pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, ReadError> { 221 /// Starts a task to deliver output reports from the Interrupt Out pipe to
220 assert!(buf.len() >= N); 222 /// `handler`.
221 223 pub async fn run<T: RequestHandler>(mut self, handler: &T) -> ! {
222 // Wait until a packet is ready to read from the endpoint or a SET_REPORT control request is received 224 assert!(self.offset == 0);
223 { 225 let mut buf = [0; N];
224 let data_ready = self.ep_out.wait_data_ready(); 226 loop {
225 pin_mut!(data_ready); 227 match self.read(&mut buf).await {
226 match select(data_ready, self.lease.lend(buf)).await { 228 Ok(len) => { handler.set_report(ReportId::Out(0), &buf[0..len]); }
227 Either::Left(_) => (), 229 Err(ReadError::BufferOverflow) => warn!("Host sent output report larger than the configured maximum output report length ({})", N),
228 Either::Right((len, _)) => return Ok(len), 230 Err(ReadError::Sync(_)) => unreachable!(),
229 } 231 }
230 } 232 }
233 }
234
235 /// Reads an output report from the Interrupt Out pipe.
236 ///
237 /// **Note:** Any reports sent from the host over the control pipe will be
238 /// passed to [`RequestHandler::set_report()`] for handling. The application
239 /// is responsible for ensuring output reports from both pipes are handled
240 /// correctly.
241 ///
242 /// **Note:** If `N` > the maximum packet size of the endpoint (i.e. output
243 /// reports may be split across multiple packets) and this method's future
244 /// is dropped after some packets have been read, the next call to `read()`
245 /// will return a [`ReadError::SyncError()`]. The range in the sync error
246 /// indicates the portion `buf` that was filled by the current call to
247 /// `read()`. If the dropped future used the same `buf`, then `buf` will
248 /// contain the full report.
249 pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, ReadError> {
250 assert!(buf.len() >= N);
231 251
232 // Read packets from the endpoint 252 // Read packets from the endpoint
233 let max_packet_size = usize::from(self.ep_out.info().max_packet_size); 253 let max_packet_size = usize::from(self.ep_out.info().max_packet_size);
234 let mut total = 0; 254 let starting_offset = self.offset;
235 for chunk in buf.chunks_mut(max_packet_size) { 255 for chunk in buf[starting_offset..].chunks_mut(max_packet_size) {
236 let size = self.ep_out.read(chunk).await?; 256 let size = self.ep_out.read(chunk).await?;
237 total += size; 257 self.offset += size;
238 if size < max_packet_size || total == N { 258 if size < max_packet_size || self.offset == N {
239 break; 259 break;
240 } 260 }
241 } 261 }
242 262
243 Ok(total) 263 let total = self.offset;
264 self.offset = 0;
265 if starting_offset > 0 {
266 Err(ReadError::Sync(starting_offset..total))
267 } else {
268 Ok(total)
269 }
244 } 270 }
245} 271}
246 272
@@ -254,10 +280,6 @@ pub trait RequestHandler {
254 } 280 }
255 281
256 /// Sets the value of report `id` to `data`. 282 /// Sets the value of report `id` to `data`.
257 ///
258 /// If an output endpoint has been allocated, output reports
259 /// are routed through [`HidClass::output()`]. Otherwise they
260 /// are sent here, along with input and feature reports.
261 fn set_report(&self, id: ReportId, data: &[u8]) -> OutResponse { 283 fn set_report(&self, id: ReportId, data: &[u8]) -> OutResponse {
262 let _ = (id, data); 284 let _ = (id, data);
263 OutResponse::Rejected 285 OutResponse::Rejected
@@ -266,8 +288,8 @@ pub trait RequestHandler {
266 /// Get the idle rate for `id`. 288 /// Get the idle rate for `id`.
267 /// 289 ///
268 /// If `id` is `None`, get the idle rate for all reports. Returning `None` 290 /// If `id` is `None`, get the idle rate for all reports. Returning `None`
269 /// will reject the control request. Any duration above 1.020 seconds or 0 291 /// will reject the control request. Any duration at or above 1.024 seconds
270 /// will be returned as an indefinite idle rate. 292 /// or below 4ms will be returned as an indefinite idle rate.
271 fn get_idle(&self, id: Option<ReportId>) -> Option<Duration> { 293 fn get_idle(&self, id: Option<ReportId>) -> Option<Duration> {
272 let _ = id; 294 let _ = id;
273 None 295 None
@@ -284,7 +306,6 @@ pub trait RequestHandler {
284 306
285struct Control<'d> { 307struct Control<'d> {
286 report_descriptor: &'static [u8], 308 report_descriptor: &'static [u8],
287 out_lease: Option<&'d AsyncLease>,
288 request_handler: Option<&'d dyn RequestHandler>, 309 request_handler: Option<&'d dyn RequestHandler>,
289 hid_descriptor: [u8; 9], 310 hid_descriptor: [u8; 9],
290} 311}
@@ -292,12 +313,10 @@ struct Control<'d> {
292impl<'a> Control<'a> { 313impl<'a> Control<'a> {
293 fn new( 314 fn new(
294 report_descriptor: &'static [u8], 315 report_descriptor: &'static [u8],
295 out_lease: Option<&'a AsyncLease>,
296 request_handler: Option<&'a dyn RequestHandler>, 316 request_handler: Option<&'a dyn RequestHandler>,
297 ) -> Self { 317 ) -> Self {
298 Control { 318 Control {
299 report_descriptor, 319 report_descriptor,
300 out_lease,
301 request_handler, 320 request_handler,
302 hid_descriptor: [ 321 hid_descriptor: [
303 // Length of buf inclusive of size prefix 322 // Length of buf inclusive of size prefix
@@ -370,7 +389,7 @@ impl<'d> ControlHandler for Control<'d> {
370 if let RequestType::Class = req.request_type { 389 if let RequestType::Class = req.request_type {
371 match req.request { 390 match req.request {
372 HID_REQ_SET_IDLE => { 391 HID_REQ_SET_IDLE => {
373 if let Some(handler) = self.request_handler.as_ref() { 392 if let Some(handler) = self.request_handler {
374 let id = req.value as u8; 393 let id = req.value as u8;
375 let id = (id != 0).then(|| ReportId::In(id)); 394 let id = (id != 0).then(|| ReportId::In(id));
376 let dur = u64::from(req.value >> 8); 395 let dur = u64::from(req.value >> 8);
@@ -383,25 +402,8 @@ impl<'d> ControlHandler for Control<'d> {
383 } 402 }
384 OutResponse::Accepted 403 OutResponse::Accepted
385 } 404 }
386 HID_REQ_SET_REPORT => match ( 405 HID_REQ_SET_REPORT => match (ReportId::try_from(req.value), self.request_handler) {
387 ReportId::try_from(req.value), 406 (Ok(id), Some(handler)) => handler.set_report(id, data),
388 self.out_lease,
389 self.request_handler.as_ref(),
390 ) {
391 (Ok(ReportId::Out(_)), Some(lease), _) => {
392 match lease.try_borrow_mut(|buf| {
393 let len = buf.len().min(data.len());
394 buf[0..len].copy_from_slice(&data[0..len]);
395 len
396 }) {
397 Ok(()) => OutResponse::Accepted,
398 Err(_) => {
399 warn!("SET_REPORT received for output report with no reader listening.");
400 OutResponse::Rejected
401 }
402 }
403 }
404 (Ok(id), _, Some(handler)) => handler.set_report(id, data),
405 _ => OutResponse::Rejected, 407 _ => OutResponse::Rejected,
406 }, 408 },
407 HID_REQ_SET_PROTOCOL => { 409 HID_REQ_SET_PROTOCOL => {
@@ -429,10 +431,7 @@ impl<'d> ControlHandler for Control<'d> {
429 }, 431 },
430 (RequestType::Class, HID_REQ_GET_REPORT) => { 432 (RequestType::Class, HID_REQ_GET_REPORT) => {
431 let size = match ReportId::try_from(req.value) { 433 let size = match ReportId::try_from(req.value) {
432 Ok(id) => self 434 Ok(id) => self.request_handler.and_then(|x| x.get_report(id, buf)),
433 .request_handler
434 .as_ref()
435 .and_then(|x| x.get_report(id, buf)),
436 Err(_) => None, 435 Err(_) => None,
437 }; 436 };
438 437
@@ -443,7 +442,7 @@ impl<'d> ControlHandler for Control<'d> {
443 } 442 }
444 } 443 }
445 (RequestType::Class, HID_REQ_GET_IDLE) => { 444 (RequestType::Class, HID_REQ_GET_IDLE) => {
446 if let Some(handler) = self.request_handler.as_ref() { 445 if let Some(handler) = self.request_handler {
447 let id = req.value as u8; 446 let id = req.value as u8;
448 let id = (id != 0).then(|| ReportId::In(id)); 447 let id = (id != 0).then(|| ReportId::In(id));
449 if let Some(dur) = handler.get_idle(id) { 448 if let Some(dur) = handler.get_idle(id) {
diff --git a/embassy-usb/src/driver.rs b/embassy-usb/src/driver.rs
index 03e39b8c9..82b59bd1e 100644
--- a/embassy-usb/src/driver.rs
+++ b/embassy-usb/src/driver.rs
@@ -122,20 +122,12 @@ pub trait EndpointOut: Endpoint {
122 type ReadFuture<'a>: Future<Output = Result<usize, ReadError>> + 'a 122 type ReadFuture<'a>: Future<Output = Result<usize, ReadError>> + 'a
123 where 123 where
124 Self: 'a; 124 Self: 'a;
125 type DataReadyFuture<'a>: Future<Output = ()> + 'a
126 where
127 Self: 'a;
128 125
129 /// Reads a single packet of data from the endpoint, and returns the actual length of 126 /// Reads a single packet of data from the endpoint, and returns the actual length of
130 /// the packet. 127 /// the packet.
131 /// 128 ///
132 /// This should also clear any NAK flags and prepare the endpoint to receive the next packet. 129 /// This should also clear any NAK flags and prepare the endpoint to receive the next packet.
133 fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a>; 130 fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a>;
134
135 /// Waits until a packet of data is ready to be read from the endpoint.
136 ///
137 /// A call to[`read()`](Self::read()) after this future completes should not block.
138 fn wait_data_ready<'a>(&'a mut self) -> Self::DataReadyFuture<'a>;
139} 131}
140 132
141pub trait ControlPipe { 133pub trait ControlPipe {
diff --git a/examples/nrf/src/bin/usb_hid.rs b/examples/nrf/src/bin/usb_hid.rs
index 5253f225d..6ffb1fd40 100644
--- a/examples/nrf/src/bin/usb_hid.rs
+++ b/examples/nrf/src/bin/usb_hid.rs
@@ -52,7 +52,7 @@ async fn main(_spawner: Spawner, p: Peripherals) {
52 let mut control_buf = [0; 16]; 52 let mut control_buf = [0; 16];
53 let request_handler = MyRequestHandler {}; 53 let request_handler = MyRequestHandler {};
54 54
55 let mut state = State::<5, 0>::new(); 55 let mut control = State::<5, 0>::new();
56 56
57 let mut builder = UsbDeviceBuilder::new( 57 let mut builder = UsbDeviceBuilder::new(
58 driver, 58 driver,
@@ -66,7 +66,7 @@ async fn main(_spawner: Spawner, p: Peripherals) {
66 // Create classes on the builder. 66 // Create classes on the builder.
67 let mut hid = HidClass::new( 67 let mut hid = HidClass::new(
68 &mut builder, 68 &mut builder,
69 &mut state, 69 &mut control,
70 MouseReport::desc(), 70 MouseReport::desc(),
71 Some(&request_handler), 71 Some(&request_handler),
72 60, 72 60,