aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoralexmoon <[email protected]>2022-04-05 17:23:46 -0400
committerDario Nieuwenhuis <[email protected]>2022-04-06 05:38:11 +0200
commita1754ac8a820d9cae97cf214969faf3090b37c76 (patch)
treecf6f9d3fcb0a41d175252521fec09b80085a7180
parent22a47aeeb2bc9d459a6e83414632890164a7b448 (diff)
embassy-usb-hid bug fixes
-rw-r--r--embassy-usb-hid/src/lib.rs87
-rw-r--r--examples/nrf/src/bin/usb_hid_keyboard.rs4
2 files changed, 70 insertions, 21 deletions
diff --git a/embassy-usb-hid/src/lib.rs b/embassy-usb-hid/src/lib.rs
index 996de6a5b..f50c5f8cb 100644
--- a/embassy-usb-hid/src/lib.rs
+++ b/embassy-usb-hid/src/lib.rs
@@ -9,6 +9,7 @@ pub(crate) mod fmt;
9 9
10use core::mem::MaybeUninit; 10use core::mem::MaybeUninit;
11use core::ops::Range; 11use core::ops::Range;
12use core::sync::atomic::{AtomicUsize, Ordering};
12 13
13use embassy::time::Duration; 14use embassy::time::Duration;
14use embassy_usb::driver::EndpointOut; 15use embassy_usb::driver::EndpointOut;
@@ -61,12 +62,14 @@ impl ReportId {
61 62
62pub struct State<'a, const IN_N: usize, const OUT_N: usize> { 63pub struct State<'a, const IN_N: usize, const OUT_N: usize> {
63 control: MaybeUninit<Control<'a>>, 64 control: MaybeUninit<Control<'a>>,
65 out_report_offset: AtomicUsize,
64} 66}
65 67
66impl<'a, const IN_N: usize, const OUT_N: usize> State<'a, IN_N, OUT_N> { 68impl<'a, const IN_N: usize, const OUT_N: usize> State<'a, IN_N, OUT_N> {
67 pub fn new() -> Self { 69 pub fn new() -> Self {
68 State { 70 State {
69 control: MaybeUninit::uninit(), 71 control: MaybeUninit::uninit(),
72 out_report_offset: AtomicUsize::new(0),
70 } 73 }
71 } 74 }
72} 75}
@@ -94,9 +97,11 @@ impl<'d, D: Driver<'d>, const IN_N: usize> HidClass<'d, D, (), IN_N> {
94 max_packet_size: u16, 97 max_packet_size: u16,
95 ) -> Self { 98 ) -> Self {
96 let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms); 99 let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms);
97 let control = state 100 let control = state.control.write(Control::new(
98 .control 101 report_descriptor,
99 .write(Control::new(report_descriptor, request_handler)); 102 request_handler,
103 &state.out_report_offset,
104 ));
100 control.build(builder, None, &ep_in); 105 control.build(builder, None, &ep_in);
101 106
102 Self { 107 Self {
@@ -138,14 +143,19 @@ impl<'d, D: Driver<'d>, const IN_N: usize, const OUT_N: usize>
138 let ep_out = builder.alloc_interrupt_endpoint_out(max_packet_size, poll_ms); 143 let ep_out = builder.alloc_interrupt_endpoint_out(max_packet_size, poll_ms);
139 let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms); 144 let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms);
140 145
141 let control = state 146 let control = state.control.write(Control::new(
142 .control 147 report_descriptor,
143 .write(Control::new(report_descriptor, request_handler)); 148 request_handler,
149 &state.out_report_offset,
150 ));
144 control.build(builder, Some(&ep_out), &ep_in); 151 control.build(builder, Some(&ep_out), &ep_in);
145 152
146 Self { 153 Self {
147 input: ReportWriter { ep_in }, 154 input: ReportWriter { ep_in },
148 output: ReportReader { ep_out, offset: 0 }, 155 output: ReportReader {
156 ep_out,
157 offset: &state.out_report_offset,
158 },
149 } 159 }
150 } 160 }
151 161
@@ -166,7 +176,7 @@ pub struct ReportWriter<'d, D: Driver<'d>, const N: usize> {
166 176
167pub struct ReportReader<'d, D: Driver<'d>, const N: usize> { 177pub struct ReportReader<'d, D: Driver<'d>, const N: usize> {
168 ep_out: D::EndpointOut, 178 ep_out: D::EndpointOut,
169 offset: usize, 179 offset: &'d AtomicUsize,
170} 180}
171 181
172#[derive(Debug, Clone, PartialEq, Eq)] 182#[derive(Debug, Clone, PartialEq, Eq)]
@@ -188,6 +198,11 @@ impl From<embassy_usb::driver::ReadError> for ReadError {
188} 198}
189 199
190impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> { 200impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> {
201 /// Waits for the interrupt in endpoint to be enabled.
202 pub async fn ready(&mut self) -> () {
203 self.ep_in.wait_enabled().await
204 }
205
191 /// Tries to write an input report by serializing the given report structure. 206 /// Tries to write an input report by serializing the given report structure.
192 /// 207 ///
193 /// Panics if no endpoint is available. 208 /// Panics if no endpoint is available.
@@ -222,14 +237,27 @@ impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> {
222} 237}
223 238
224impl<'d, D: Driver<'d>, const N: usize> ReportReader<'d, D, N> { 239impl<'d, D: Driver<'d>, const N: usize> ReportReader<'d, D, N> {
240 /// Waits for the interrupt out endpoint to be enabled.
241 pub async fn ready(&mut self) -> () {
242 self.ep_out.wait_enabled().await
243 }
244
225 /// Starts a task to deliver output reports from the Interrupt Out pipe to 245 /// Starts a task to deliver output reports from the Interrupt Out pipe to
226 /// `handler`. 246 /// `handler`.
227 pub async fn run<T: RequestHandler>(mut self, handler: &T) -> ! { 247 ///
228 assert!(self.offset == 0); 248 /// Terminates when the interface becomes disabled.
249 ///
250 /// If `use_report_ids` is true, the first byte of the report will be used as
251 /// the `ReportId` value. Otherwise the `ReportId` value will be 0.
252 pub async fn run<T: RequestHandler>(mut self, use_report_ids: bool, handler: &T) -> ! {
253 let offset = self.offset.load(Ordering::Acquire);
254 assert!(offset == 0);
229 let mut buf = [0; N]; 255 let mut buf = [0; N];
230 loop { 256 loop {
231 match self.read(&mut buf).await { 257 match self.read(&mut buf).await {
232 Ok(len) => { handler.set_report(ReportId::Out(0), &buf[0..len]); } 258 Ok(len) => {
259 let id = if use_report_ids { buf[0] } else { 0 };
260 handler.set_report(ReportId::Out(id), &buf[..len]); }
233 Err(ReadError::BufferOverflow) => warn!("Host sent output report larger than the configured maximum output report length ({})", N), 261 Err(ReadError::BufferOverflow) => warn!("Host sent output report larger than the configured maximum output report length ({})", N),
234 Err(ReadError::Disabled) => self.ep_out.wait_enabled().await, 262 Err(ReadError::Disabled) => self.ep_out.wait_enabled().await,
235 Err(ReadError::Sync(_)) => unreachable!(), 263 Err(ReadError::Sync(_)) => unreachable!(),
@@ -257,17 +285,33 @@ impl<'d, D: Driver<'d>, const N: usize> ReportReader<'d, D, N> {
257 285
258 // Read packets from the endpoint 286 // Read packets from the endpoint
259 let max_packet_size = usize::from(self.ep_out.info().max_packet_size); 287 let max_packet_size = usize::from(self.ep_out.info().max_packet_size);
260 let starting_offset = self.offset; 288 let starting_offset = self.offset.load(Ordering::Acquire);
261 for chunk in buf[starting_offset..].chunks_mut(max_packet_size) { 289 let mut total = starting_offset;
262 let size = self.ep_out.read(chunk).await?; 290 loop {
263 self.offset += size; 291 for chunk in buf[starting_offset..N].chunks_mut(max_packet_size) {
264 if size < max_packet_size || self.offset == N { 292 match self.ep_out.read(chunk).await {
293 Ok(size) => {
294 total += size;
295 if size < max_packet_size || total == N {
296 self.offset.store(0, Ordering::Release);
297 break;
298 } else {
299 self.offset.store(total, Ordering::Release);
300 }
301 }
302 Err(err) => {
303 self.offset.store(0, Ordering::Release);
304 return Err(err.into());
305 }
306 }
307 }
308
309 // Some hosts may send ZLPs even when not required by the HID spec, so we'll loop as long as total == 0.
310 if total > 0 {
265 break; 311 break;
266 } 312 }
267 } 313 }
268 314
269 let total = self.offset;
270 self.offset = 0;
271 if starting_offset > 0 { 315 if starting_offset > 0 {
272 Err(ReadError::Sync(starting_offset..total)) 316 Err(ReadError::Sync(starting_offset..total))
273 } else { 317 } else {
@@ -313,6 +357,7 @@ pub trait RequestHandler {
313struct Control<'d> { 357struct Control<'d> {
314 report_descriptor: &'static [u8], 358 report_descriptor: &'static [u8],
315 request_handler: Option<&'d dyn RequestHandler>, 359 request_handler: Option<&'d dyn RequestHandler>,
360 out_report_offset: &'d AtomicUsize,
316 hid_descriptor: [u8; 9], 361 hid_descriptor: [u8; 9],
317} 362}
318 363
@@ -320,10 +365,12 @@ impl<'a> Control<'a> {
320 fn new( 365 fn new(
321 report_descriptor: &'static [u8], 366 report_descriptor: &'static [u8],
322 request_handler: Option<&'a dyn RequestHandler>, 367 request_handler: Option<&'a dyn RequestHandler>,
368 out_report_offset: &'a AtomicUsize,
323 ) -> Self { 369 ) -> Self {
324 Control { 370 Control {
325 report_descriptor, 371 report_descriptor,
326 request_handler, 372 request_handler,
373 out_report_offset,
327 hid_descriptor: [ 374 hid_descriptor: [
328 // Length of buf inclusive of size prefix 375 // Length of buf inclusive of size prefix
329 9, 376 9,
@@ -388,7 +435,9 @@ impl<'a> Control<'a> {
388} 435}
389 436
390impl<'d> ControlHandler for Control<'d> { 437impl<'d> ControlHandler for Control<'d> {
391 fn reset(&mut self) {} 438 fn reset(&mut self) {
439 self.out_report_offset.store(0, Ordering::Release);
440 }
392 441
393 fn control_out(&mut self, req: embassy_usb::control::Request, data: &[u8]) -> OutResponse { 442 fn control_out(&mut self, req: embassy_usb::control::Request, data: &[u8]) -> OutResponse {
394 trace!("HID control_out {:?} {=[u8]:x}", req, data); 443 trace!("HID control_out {:?} {=[u8]:x}", req, data);
diff --git a/examples/nrf/src/bin/usb_hid_keyboard.rs b/examples/nrf/src/bin/usb_hid_keyboard.rs
index af70a9a60..51136292f 100644
--- a/examples/nrf/src/bin/usb_hid_keyboard.rs
+++ b/examples/nrf/src/bin/usb_hid_keyboard.rs
@@ -54,7 +54,7 @@ async fn main(_spawner: Spawner, p: Peripherals) {
54 let mut control_buf = [0; 16]; 54 let mut control_buf = [0; 16];
55 let request_handler = MyRequestHandler {}; 55 let request_handler = MyRequestHandler {};
56 56
57 let mut state = State::<64, 64>::new(); 57 let mut state = State::<8, 1>::new();
58 58
59 let mut builder = UsbDeviceBuilder::new( 59 let mut builder = UsbDeviceBuilder::new(
60 driver, 60 driver,
@@ -117,7 +117,7 @@ async fn main(_spawner: Spawner, p: Peripherals) {
117 }; 117 };
118 118
119 let out_fut = async { 119 let out_fut = async {
120 hid_out.run(&MyRequestHandler {}).await; 120 hid_out.run(false, &request_handler).await;
121 }; 121 };
122 // Run everything concurrently. 122 // Run everything concurrently.
123 // If we had made everything `'static` above instead, we could do this using separate tasks instead. 123 // If we had made everything `'static` above instead, we could do this using separate tasks instead.