aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoralexmoon <[email protected]>2022-03-30 14:17:15 -0400
committerDario Nieuwenhuis <[email protected]>2022-04-06 05:38:11 +0200
commitf5ba022257ccd9ddd371f1dcd10c0775cc5a3110 (patch)
treece48804d72e8a936a026a3479c49fede8a8352f9
parent77e0aca03b89ebc5f1e93b6c64b6c91ca10cedd1 (diff)
Refactor ControlPipe to use the typestate pattern for safety
-rw-r--r--embassy-usb/src/control.rs121
-rw-r--r--embassy-usb/src/lib.rs158
2 files changed, 157 insertions, 122 deletions
diff --git a/embassy-usb/src/control.rs b/embassy-usb/src/control.rs
index b5077c732..9f1115ff2 100644
--- a/embassy-usb/src/control.rs
+++ b/embassy-usb/src/control.rs
@@ -1,5 +1,7 @@
1use core::mem; 1use core::mem;
2 2
3use crate::descriptor::DescriptorWriter;
4use crate::driver::{self, ReadError};
3use crate::DEFAULT_ALTERNATE_SETTING; 5use crate::DEFAULT_ALTERNATE_SETTING;
4 6
5use super::types::*; 7use super::types::*;
@@ -191,3 +193,122 @@ pub trait ControlHandler {
191 InResponse::Accepted(&buf[0..2]) 193 InResponse::Accepted(&buf[0..2])
192 } 194 }
193} 195}
196
197/// Typestate representing a ControlPipe in the DATA IN stage
198#[derive(Debug)]
199#[cfg_attr(feature = "defmt", derive(defmt::Format))]
200pub(crate) struct DataInStage {
201 length: usize,
202}
203
204/// Typestate representing a ControlPipe in the DATA OUT stage
205#[derive(Debug)]
206#[cfg_attr(feature = "defmt", derive(defmt::Format))]
207pub(crate) struct DataOutStage {
208 length: usize,
209}
210
211/// Typestate representing a ControlPipe in the STATUS stage
212#[derive(Debug)]
213#[cfg_attr(feature = "defmt", derive(defmt::Format))]
214pub(crate) struct StatusStage {}
215
216#[derive(Debug)]
217#[cfg_attr(feature = "defmt", derive(defmt::Format))]
218pub(crate) enum Setup {
219 DataIn(Request, DataInStage),
220 DataOut(Request, DataOutStage),
221}
222
223pub(crate) struct ControlPipe<C: driver::ControlPipe> {
224 control: C,
225}
226
227impl<C: driver::ControlPipe> ControlPipe<C> {
228 pub(crate) fn new(control: C) -> Self {
229 ControlPipe { control }
230 }
231
232 pub(crate) async fn setup(&mut self) -> Setup {
233 let req = self.control.setup().await;
234 match (req.direction, req.length) {
235 (UsbDirection::Out, n) => Setup::DataOut(
236 req,
237 DataOutStage {
238 length: usize::from(n),
239 },
240 ),
241 (UsbDirection::In, n) => Setup::DataIn(
242 req,
243 DataInStage {
244 length: usize::from(n),
245 },
246 ),
247 }
248 }
249
250 pub(crate) async fn data_out<'a>(
251 &mut self,
252 buf: &'a mut [u8],
253 stage: DataOutStage,
254 ) -> Result<(&'a [u8], StatusStage), ReadError> {
255 if stage.length == 0 {
256 Ok((&[], StatusStage {}))
257 } else {
258 let req_length = stage.length;
259 let max_packet_size = self.control.max_packet_size();
260 let mut total = 0;
261
262 for chunk in buf.chunks_mut(max_packet_size) {
263 let size = self.control.data_out(chunk).await?;
264 total += size;
265 if size < max_packet_size || total == req_length {
266 break;
267 }
268 }
269
270 Ok((&buf[0..total], StatusStage {}))
271 }
272 }
273
274 pub(crate) async fn accept_in(&mut self, buf: &[u8], stage: DataInStage) {
275 #[cfg(feature = "defmt")]
276 debug!("control in accept {:x}", buf);
277 #[cfg(not(feature = "defmt"))]
278 debug!("control in accept {:x?}", buf);
279
280 let req_len = stage.length;
281 let len = buf.len().min(req_len);
282 let max_packet_size = self.control.max_packet_size();
283 let need_zlp = len != req_len && (len % usize::from(max_packet_size)) == 0;
284
285 let mut chunks = buf[0..len]
286 .chunks(max_packet_size)
287 .chain(need_zlp.then(|| -> &[u8] { &[] }));
288
289 while let Some(chunk) = chunks.next() {
290 self.control.data_in(chunk, chunks.size_hint().0 == 0).await;
291 }
292 }
293
294 pub(crate) async fn accept_in_writer(
295 &mut self,
296 req: Request,
297 stage: DataInStage,
298 f: impl FnOnce(&mut DescriptorWriter),
299 ) {
300 let mut buf = [0; 256];
301 let mut w = DescriptorWriter::new(&mut buf);
302 f(&mut w);
303 let pos = w.position().min(usize::from(req.length));
304 self.accept_in(&buf[..pos], stage).await
305 }
306
307 pub(crate) fn accept(&mut self, _: StatusStage) {
308 self.control.accept();
309 }
310
311 pub(crate) fn reject(&mut self) {
312 self.control.reject();
313 }
314}
diff --git a/embassy-usb/src/lib.rs b/embassy-usb/src/lib.rs
index 77a9c33be..067b5b07f 100644
--- a/embassy-usb/src/lib.rs
+++ b/embassy-usb/src/lib.rs
@@ -16,7 +16,7 @@ use heapless::Vec;
16 16
17use self::control::*; 17use self::control::*;
18use self::descriptor::*; 18use self::descriptor::*;
19use self::driver::*; 19use self::driver::{Bus, Driver, Event};
20use self::types::*; 20use self::types::*;
21use self::util::*; 21use self::util::*;
22 22
@@ -92,10 +92,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
92 Self { 92 Self {
93 bus: driver, 93 bus: driver,
94 config, 94 config,
95 control: ControlPipe { 95 control: ControlPipe::new(control),
96 control,
97 request: None,
98 },
99 device_descriptor, 96 device_descriptor,
100 config_descriptor, 97 config_descriptor,
101 bos_descriptor, 98 bos_descriptor,
@@ -134,57 +131,50 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
134 Either::Right(req) => { 131 Either::Right(req) => {
135 debug!("control request: {:x}", req); 132 debug!("control request: {:x}", req);
136 133
137 match req.direction { 134 match req {
138 UsbDirection::In => self.handle_control_in(req).await, 135 Setup::DataIn(req, stage) => self.handle_control_in(req, stage).await,
139 UsbDirection::Out => self.handle_control_out(req).await, 136 Setup::DataOut(req, stage) => self.handle_control_out(req, stage).await,
140 } 137 }
141 } 138 }
142 } 139 }
143 } 140 }
144 } 141 }
145 142
146 async fn handle_control_out(&mut self, req: Request) { 143 async fn handle_control_out(&mut self, req: Request, stage: DataOutStage) {
147 const CONFIGURATION_NONE_U16: u16 = CONFIGURATION_NONE as u16; 144 const CONFIGURATION_NONE_U16: u16 = CONFIGURATION_NONE as u16;
148 const CONFIGURATION_VALUE_U16: u16 = CONFIGURATION_VALUE as u16; 145 const CONFIGURATION_VALUE_U16: u16 = CONFIGURATION_VALUE as u16;
149 146
150 // If the request has a data state, we must read it. 147 let (data, stage) = match self.control.data_out(self.control_buf, stage).await {
151 let data = if req.length > 0 { 148 Ok(data) => data,
152 match self.control.data_out(self.control_buf).await { 149 Err(_) => {
153 Ok(data) => data, 150 warn!("usb: failed to read CONTROL OUT data stage.");
154 Err(_) => { 151 return;
155 warn!("usb: failed to read CONTROL OUT data stage.");
156 return;
157 }
158 } 152 }
159 } else {
160 &[]
161 }; 153 };
162 154
163 match (req.request_type, req.recipient) { 155 match (req.request_type, req.recipient) {
164 (RequestType::Standard, Recipient::Device) => match (req.request, req.value) { 156 (RequestType::Standard, Recipient::Device) => match (req.request, req.value) {
165 (Request::CLEAR_FEATURE, Request::FEATURE_DEVICE_REMOTE_WAKEUP) => { 157 (Request::CLEAR_FEATURE, Request::FEATURE_DEVICE_REMOTE_WAKEUP) => {
166 self.remote_wakeup_enabled = false; 158 self.remote_wakeup_enabled = false;
167 self.control.accept(); 159 self.control.accept(stage)
168 } 160 }
169 (Request::SET_FEATURE, Request::FEATURE_DEVICE_REMOTE_WAKEUP) => { 161 (Request::SET_FEATURE, Request::FEATURE_DEVICE_REMOTE_WAKEUP) => {
170 self.remote_wakeup_enabled = true; 162 self.remote_wakeup_enabled = true;
171 self.control.accept(); 163 self.control.accept(stage)
172 } 164 }
173 (Request::SET_ADDRESS, 1..=127) => { 165 (Request::SET_ADDRESS, 1..=127) => {
174 self.pending_address = req.value as u8; 166 self.pending_address = req.value as u8;
175 self.control.accept(); 167 self.control.accept(stage)
176 } 168 }
177 (Request::SET_CONFIGURATION, CONFIGURATION_VALUE_U16) => { 169 (Request::SET_CONFIGURATION, CONFIGURATION_VALUE_U16) => {
178 self.device_state = UsbDeviceState::Configured; 170 self.device_state = UsbDeviceState::Configured;
179 self.control.accept(); 171 self.control.accept(stage)
180 } 172 }
181 (Request::SET_CONFIGURATION, CONFIGURATION_NONE_U16) => match self.device_state { 173 (Request::SET_CONFIGURATION, CONFIGURATION_NONE_U16) => match self.device_state {
182 UsbDeviceState::Default => { 174 UsbDeviceState::Default => self.control.accept(stage),
183 self.control.accept();
184 }
185 _ => { 175 _ => {
186 self.device_state = UsbDeviceState::Addressed; 176 self.device_state = UsbDeviceState::Addressed;
187 self.control.accept(); 177 self.control.accept(stage)
188 } 178 }
189 }, 179 },
190 _ => self.control.reject(), 180 _ => self.control.reject(),
@@ -193,12 +183,12 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
193 (Request::SET_FEATURE, Request::FEATURE_ENDPOINT_HALT) => { 183 (Request::SET_FEATURE, Request::FEATURE_ENDPOINT_HALT) => {
194 let ep_addr = ((req.index as u8) & 0x8f).into(); 184 let ep_addr = ((req.index as u8) & 0x8f).into();
195 self.bus.set_stalled(ep_addr, true); 185 self.bus.set_stalled(ep_addr, true);
196 self.control.accept(); 186 self.control.accept(stage)
197 } 187 }
198 (Request::CLEAR_FEATURE, Request::FEATURE_ENDPOINT_HALT) => { 188 (Request::CLEAR_FEATURE, Request::FEATURE_ENDPOINT_HALT) => {
199 let ep_addr = ((req.index as u8) & 0x8f).into(); 189 let ep_addr = ((req.index as u8) & 0x8f).into();
200 self.bus.set_stalled(ep_addr, false); 190 self.bus.set_stalled(ep_addr, false);
201 self.control.accept(); 191 self.control.accept(stage)
202 } 192 }
203 _ => self.control.reject(), 193 _ => self.control.reject(),
204 }, 194 },
@@ -218,7 +208,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
218 _ => handler.control_out(req, data), 208 _ => handler.control_out(req, data),
219 }; 209 };
220 match response { 210 match response {
221 OutResponse::Accepted => self.control.accept(), 211 OutResponse::Accepted => self.control.accept(stage),
222 OutResponse::Rejected => self.control.reject(), 212 OutResponse::Rejected => self.control.reject(),
223 } 213 }
224 } 214 }
@@ -229,7 +219,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
229 } 219 }
230 } 220 }
231 221
232 async fn handle_control_in(&mut self, req: Request) { 222 async fn handle_control_in(&mut self, req: Request, stage: DataInStage) {
233 match (req.request_type, req.recipient) { 223 match (req.request_type, req.recipient) {
234 (RequestType::Standard, Recipient::Device) => match req.request { 224 (RequestType::Standard, Recipient::Device) => match req.request {
235 Request::GET_STATUS => { 225 Request::GET_STATUS => {
@@ -240,17 +230,15 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
240 if self.remote_wakeup_enabled { 230 if self.remote_wakeup_enabled {
241 status |= 0x0002; 231 status |= 0x0002;
242 } 232 }
243 self.control.accept_in(&status.to_le_bytes()).await; 233 self.control.accept_in(&status.to_le_bytes(), stage).await
244 }
245 Request::GET_DESCRIPTOR => {
246 self.handle_get_descriptor(req).await;
247 } 234 }
235 Request::GET_DESCRIPTOR => self.handle_get_descriptor(req, stage).await,
248 Request::GET_CONFIGURATION => { 236 Request::GET_CONFIGURATION => {
249 let status = match self.device_state { 237 let status = match self.device_state {
250 UsbDeviceState::Configured => CONFIGURATION_VALUE, 238 UsbDeviceState::Configured => CONFIGURATION_VALUE,
251 _ => CONFIGURATION_NONE, 239 _ => CONFIGURATION_NONE,
252 }; 240 };
253 self.control.accept_in(&status.to_le_bytes()).await; 241 self.control.accept_in(&status.to_le_bytes(), stage).await
254 } 242 }
255 _ => self.control.reject(), 243 _ => self.control.reject(),
256 }, 244 },
@@ -261,7 +249,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
261 if self.bus.is_stalled(ep_addr) { 249 if self.bus.is_stalled(ep_addr) {
262 status |= 0x0001; 250 status |= 0x0001;
263 } 251 }
264 self.control.accept_in(&status.to_le_bytes()).await; 252 self.control.accept_in(&status.to_le_bytes(), stage).await
265 } 253 }
266 _ => self.control.reject(), 254 _ => self.control.reject(),
267 }, 255 },
@@ -285,7 +273,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
285 }; 273 };
286 274
287 match response { 275 match response {
288 InResponse::Accepted(data) => self.control.accept_in(data).await, 276 InResponse::Accepted(data) => self.control.accept_in(data, stage).await,
289 InResponse::Rejected => self.control.reject(), 277 InResponse::Rejected => self.control.reject(),
290 } 278 }
291 } 279 }
@@ -296,17 +284,19 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
296 } 284 }
297 } 285 }
298 286
299 async fn handle_get_descriptor(&mut self, req: Request) { 287 async fn handle_get_descriptor(&mut self, req: Request, stage: DataInStage) {
300 let (dtype, index) = req.descriptor_type_index(); 288 let (dtype, index) = req.descriptor_type_index();
301 289
302 match dtype { 290 match dtype {
303 descriptor_type::BOS => self.control.accept_in(self.bos_descriptor).await, 291 descriptor_type::BOS => self.control.accept_in(self.bos_descriptor, stage).await,
304 descriptor_type::DEVICE => self.control.accept_in(self.device_descriptor).await, 292 descriptor_type::DEVICE => self.control.accept_in(self.device_descriptor, stage).await,
305 descriptor_type::CONFIGURATION => self.control.accept_in(self.config_descriptor).await, 293 descriptor_type::CONFIGURATION => {
294 self.control.accept_in(self.config_descriptor, stage).await
295 }
306 descriptor_type::STRING => { 296 descriptor_type::STRING => {
307 if index == 0 { 297 if index == 0 {
308 self.control 298 self.control
309 .accept_in_writer(req, |w| { 299 .accept_in_writer(req, stage, |w| {
310 w.write(descriptor_type::STRING, &lang_id::ENGLISH_US.to_le_bytes()); 300 w.write(descriptor_type::STRING, &lang_id::ENGLISH_US.to_le_bytes());
311 }) 301 })
312 .await 302 .await
@@ -324,7 +314,9 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
324 }; 314 };
325 315
326 if let Some(s) = s { 316 if let Some(s) = s {
327 self.control.accept_in_writer(req, |w| w.string(s)).await; 317 self.control
318 .accept_in_writer(req, stage, |w| w.string(s))
319 .await
328 } else { 320 } else {
329 self.control.reject() 321 self.control.reject()
330 } 322 }
@@ -334,81 +326,3 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
334 } 326 }
335 } 327 }
336} 328}
337
338struct ControlPipe<C: driver::ControlPipe> {
339 control: C,
340 request: Option<Request>,
341}
342
343impl<C: driver::ControlPipe> ControlPipe<C> {
344 async fn setup(&mut self) -> Request {
345 assert!(self.request.is_none());
346 let req = self.control.setup().await;
347 self.request = Some(req);
348 req
349 }
350
351 async fn data_out<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8], ReadError> {
352 let req = self.request.unwrap();
353 assert_eq!(req.direction, UsbDirection::Out);
354 assert!(req.length > 0);
355 let req_length = usize::from(req.length);
356
357 let max_packet_size = self.control.max_packet_size();
358 let mut total = 0;
359
360 for chunk in buf.chunks_mut(max_packet_size) {
361 let size = self.control.data_out(chunk).await?;
362 total += size;
363 if size < max_packet_size || total == req_length {
364 break;
365 }
366 }
367
368 Ok(&buf[0..total])
369 }
370
371 async fn accept_in(&mut self, buf: &[u8]) -> () {
372 #[cfg(feature = "defmt")]
373 debug!("control in accept {:x}", buf);
374 #[cfg(not(feature = "defmt"))]
375 debug!("control in accept {:x?}", buf);
376 let req = unwrap!(self.request);
377 assert!(req.direction == UsbDirection::In);
378
379 let req_len = usize::from(req.length);
380 let len = buf.len().min(req_len);
381 let max_packet_size = self.control.max_packet_size();
382 let need_zlp = len != req_len && (len % usize::from(max_packet_size)) == 0;
383
384 let mut chunks = buf[0..len]
385 .chunks(max_packet_size)
386 .chain(need_zlp.then(|| -> &[u8] { &[] }));
387
388 while let Some(chunk) = chunks.next() {
389 self.control.data_in(chunk, chunks.size_hint().0 == 0).await;
390 }
391
392 self.request = None;
393 }
394
395 async fn accept_in_writer(&mut self, req: Request, f: impl FnOnce(&mut DescriptorWriter)) {
396 let mut buf = [0; 256];
397 let mut w = DescriptorWriter::new(&mut buf);
398 f(&mut w);
399 let pos = w.position().min(usize::from(req.length));
400 self.accept_in(&buf[..pos]).await;
401 }
402
403 fn accept(&mut self) {
404 assert!(self.request.is_some());
405 self.control.accept();
406 self.request = None;
407 }
408
409 fn reject(&mut self) {
410 assert!(self.request.is_some());
411 self.control.reject();
412 self.request = None;
413 }
414}