diff options
| -rw-r--r-- | embassy-stm32/src/usb/usb.rs | 45 |
1 files changed, 25 insertions, 20 deletions
diff --git a/embassy-stm32/src/usb/usb.rs b/embassy-stm32/src/usb/usb.rs index b9a16bbf1..31ab8f76d 100644 --- a/embassy-stm32/src/usb/usb.rs +++ b/embassy-stm32/src/usb/usb.rs | |||
| @@ -80,10 +80,10 @@ impl<T: Instance> interrupt::typelevel::Handler<T::Interrupt> for InterruptHandl | |||
| 80 | 80 | ||
| 81 | if istr.ctr() { | 81 | if istr.ctr() { |
| 82 | let index = istr.ep_id() as usize; | 82 | let index = istr.ep_id() as usize; |
| 83 | CTR_TRIGGERED[index].store(true, Ordering::Relaxed); | ||
| 84 | 83 | ||
| 85 | let mut epr = regs.epr(index).read(); | 84 | let mut epr = regs.epr(index).read(); |
| 86 | if epr.ctr_rx() { | 85 | if epr.ctr_rx() { |
| 86 | CTR_RX_TRIGGERED[index].store(true, Ordering::Relaxed); | ||
| 87 | if index == 0 && epr.setup() { | 87 | if index == 0 && epr.setup() { |
| 88 | EP0_SETUP.store(true, Ordering::Relaxed); | 88 | EP0_SETUP.store(true, Ordering::Relaxed); |
| 89 | } | 89 | } |
| @@ -91,6 +91,7 @@ impl<T: Instance> interrupt::typelevel::Handler<T::Interrupt> for InterruptHandl | |||
| 91 | EP_OUT_WAKERS[index].wake(); | 91 | EP_OUT_WAKERS[index].wake(); |
| 92 | } | 92 | } |
| 93 | if epr.ctr_tx() { | 93 | if epr.ctr_tx() { |
| 94 | CTR_TX_TRIGGERED[index].store(true, Ordering::Relaxed); | ||
| 94 | //trace!("EP {} TX", index); | 95 | //trace!("EP {} TX", index); |
| 95 | EP_IN_WAKERS[index].wake(); | 96 | EP_IN_WAKERS[index].wake(); |
| 96 | } | 97 | } |
| @@ -122,7 +123,8 @@ const USBRAM_ALIGN: usize = 4; | |||
| 122 | static BUS_WAKER: AtomicWaker = AtomicWaker::new(); | 123 | static BUS_WAKER: AtomicWaker = AtomicWaker::new(); |
| 123 | static EP0_SETUP: AtomicBool = AtomicBool::new(false); | 124 | static EP0_SETUP: AtomicBool = AtomicBool::new(false); |
| 124 | 125 | ||
| 125 | static CTR_TRIGGERED: [AtomicBool; EP_COUNT] = [const { AtomicBool::new(false) }; EP_COUNT]; | 126 | static CTR_TX_TRIGGERED: [AtomicBool; EP_COUNT] = [const { AtomicBool::new(false) }; EP_COUNT]; |
| 127 | static CTR_RX_TRIGGERED: [AtomicBool; EP_COUNT] = [const { AtomicBool::new(false) }; EP_COUNT]; | ||
| 126 | static EP_IN_WAKERS: [AtomicWaker; EP_COUNT] = [const { AtomicWaker::new() }; EP_COUNT]; | 128 | static EP_IN_WAKERS: [AtomicWaker; EP_COUNT] = [const { AtomicWaker::new() }; EP_COUNT]; |
| 127 | static EP_OUT_WAKERS: [AtomicWaker; EP_COUNT] = [const { AtomicWaker::new() }; EP_COUNT]; | 129 | static EP_OUT_WAKERS: [AtomicWaker; EP_COUNT] = [const { AtomicWaker::new() }; EP_COUNT]; |
| 128 | static IRQ_RESET: AtomicBool = AtomicBool::new(false); | 130 | static IRQ_RESET: AtomicBool = AtomicBool::new(false); |
| @@ -209,10 +211,12 @@ mod btable { | |||
| 209 | pub(super) fn write_in_rx<T: Instance>(_index: usize, _addr: u16) {} | 211 | pub(super) fn write_in_rx<T: Instance>(_index: usize, _addr: u16) {} |
| 210 | 212 | ||
| 211 | pub(super) fn write_in_len_tx<T: Instance>(index: usize, addr: u16, len: u16) { | 213 | pub(super) fn write_in_len_tx<T: Instance>(index: usize, addr: u16, len: u16) { |
| 214 | assert_eq!(addr & 0b11, 0); | ||
| 212 | USBRAM.mem(index * 2).write_value((addr as u32) | ((len as u32) << 16)); | 215 | USBRAM.mem(index * 2).write_value((addr as u32) | ((len as u32) << 16)); |
| 213 | } | 216 | } |
| 214 | 217 | ||
| 215 | pub(super) fn write_in_len_rx<T: Instance>(index: usize, addr: u16, len: u16) { | 218 | pub(super) fn write_in_len_rx<T: Instance>(index: usize, addr: u16, len: u16) { |
| 219 | assert_eq!(addr & 0b11, 0); | ||
| 216 | USBRAM | 220 | USBRAM |
| 217 | .mem(index * 2 + 1) | 221 | .mem(index * 2 + 1) |
| 218 | .write_value((addr as u32) | ((len as u32) << 16)); | 222 | .write_value((addr as u32) | ((len as u32) << 16)); |
| @@ -640,22 +644,25 @@ impl<'d, T: Instance> driver::Bus for Bus<'d, T> { | |||
| 640 | fn endpoint_set_enabled(&mut self, ep_addr: EndpointAddress, enabled: bool) { | 644 | fn endpoint_set_enabled(&mut self, ep_addr: EndpointAddress, enabled: bool) { |
| 641 | trace!("set_enabled {:?} {}", ep_addr, enabled); | 645 | trace!("set_enabled {:?} {}", ep_addr, enabled); |
| 642 | // This can race, so do a retry loop. | 646 | // This can race, so do a retry loop. |
| 643 | let reg = T::regs().epr(ep_addr.index() as _); | 647 | let epr = T::regs().epr(ep_addr.index() as _); |
| 644 | trace!("EPR before: {:04x}", reg.read().0); | 648 | trace!("EPR before: {:04x}", epr.read().0); |
| 645 | match ep_addr.direction() { | 649 | match ep_addr.direction() { |
| 646 | Direction::In => { | 650 | Direction::In => { |
| 647 | loop { | 651 | loop { |
| 648 | let want_stat = match enabled { | 652 | let want_stat = match enabled { |
| 649 | false => Stat::DISABLED, | 653 | false => Stat::DISABLED, |
| 650 | true => Stat::NAK, | 654 | true => match epr.read().ep_type() { |
| 655 | EpType::ISO => Stat::VALID, | ||
| 656 | _ => Stat::NAK, | ||
| 657 | }, | ||
| 651 | }; | 658 | }; |
| 652 | let r = reg.read(); | 659 | let r = epr.read(); |
| 653 | if r.stat_tx() == want_stat { | 660 | if r.stat_tx() == want_stat { |
| 654 | break; | 661 | break; |
| 655 | } | 662 | } |
| 656 | let mut w = invariant(r); | 663 | let mut w = invariant(r); |
| 657 | w.set_stat_tx(Stat::from_bits(r.stat_tx().to_bits() ^ want_stat.to_bits())); | 664 | w.set_stat_tx(Stat::from_bits(r.stat_tx().to_bits() ^ want_stat.to_bits())); |
| 658 | reg.write_value(w); | 665 | epr.write_value(w); |
| 659 | } | 666 | } |
| 660 | EP_IN_WAKERS[ep_addr.index()].wake(); | 667 | EP_IN_WAKERS[ep_addr.index()].wake(); |
| 661 | } | 668 | } |
| @@ -665,18 +672,18 @@ impl<'d, T: Instance> driver::Bus for Bus<'d, T> { | |||
| 665 | false => Stat::DISABLED, | 672 | false => Stat::DISABLED, |
| 666 | true => Stat::VALID, | 673 | true => Stat::VALID, |
| 667 | }; | 674 | }; |
| 668 | let r = reg.read(); | 675 | let r = epr.read(); |
| 669 | if r.stat_rx() == want_stat { | 676 | if r.stat_rx() == want_stat { |
| 670 | break; | 677 | break; |
| 671 | } | 678 | } |
| 672 | let mut w = invariant(r); | 679 | let mut w = invariant(r); |
| 673 | w.set_stat_rx(Stat::from_bits(r.stat_rx().to_bits() ^ want_stat.to_bits())); | 680 | w.set_stat_rx(Stat::from_bits(r.stat_rx().to_bits() ^ want_stat.to_bits())); |
| 674 | reg.write_value(w); | 681 | epr.write_value(w); |
| 675 | } | 682 | } |
| 676 | EP_OUT_WAKERS[ep_addr.index()].wake(); | 683 | EP_OUT_WAKERS[ep_addr.index()].wake(); |
| 677 | } | 684 | } |
| 678 | } | 685 | } |
| 679 | trace!("EPR after: {:04x}", reg.read().0); | 686 | trace!("EPR after: {:04x}", epr.read().0); |
| 680 | } | 687 | } |
| 681 | 688 | ||
| 682 | async fn enable(&mut self) {} | 689 | async fn enable(&mut self) {} |
| @@ -836,7 +843,8 @@ impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> { | |||
| 836 | if self.info.ep_type == EndpointType::Isochronous { | 843 | if self.info.ep_type == EndpointType::Isochronous { |
| 837 | // The isochronous endpoint does not change its `STAT_RX` field to `NAK` when receiving a packet. | 844 | // The isochronous endpoint does not change its `STAT_RX` field to `NAK` when receiving a packet. |
| 838 | // Therefore, this instead waits until the `CTR` interrupt was triggered. | 845 | // Therefore, this instead waits until the `CTR` interrupt was triggered. |
| 839 | if matches!(stat, Stat::DISABLED) || CTR_TRIGGERED[index].load(Ordering::Relaxed) { | 846 | if matches!(stat, Stat::DISABLED) || CTR_RX_TRIGGERED[index].load(Ordering::Relaxed) { |
| 847 | assert!(matches!(stat, Stat::VALID | Stat::DISABLED)); | ||
| 840 | Poll::Ready(stat) | 848 | Poll::Ready(stat) |
| 841 | } else { | 849 | } else { |
| 842 | Poll::Pending | 850 | Poll::Pending |
| @@ -851,7 +859,7 @@ impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> { | |||
| 851 | }) | 859 | }) |
| 852 | .await; | 860 | .await; |
| 853 | 861 | ||
| 854 | CTR_TRIGGERED[index].store(false, Ordering::Relaxed); | 862 | CTR_RX_TRIGGERED[index].store(false, Ordering::Relaxed); |
| 855 | 863 | ||
| 856 | if stat == Stat::DISABLED { | 864 | if stat == Stat::DISABLED { |
| 857 | return Err(EndpointError::Disabled); | 865 | return Err(EndpointError::Disabled); |
| @@ -895,18 +903,17 @@ impl<'d, T: Instance> driver::EndpointIn for Endpoint<'d, T, In> { | |||
| 895 | if buf.len() > self.info.max_packet_size as usize { | 903 | if buf.len() > self.info.max_packet_size as usize { |
| 896 | return Err(EndpointError::BufferOverflow); | 904 | return Err(EndpointError::BufferOverflow); |
| 897 | } | 905 | } |
| 898 | 906 | trace!("WRITE WAITING, buf.len() = {}", buf.len()); | |
| 899 | let index = self.info.addr.index(); | 907 | let index = self.info.addr.index(); |
| 900 | |||
| 901 | trace!("WRITE WAITING"); | ||
| 902 | let stat = poll_fn(|cx| { | 908 | let stat = poll_fn(|cx| { |
| 903 | EP_IN_WAKERS[index].register(cx.waker()); | 909 | EP_IN_WAKERS[index].register(cx.waker()); |
| 904 | let regs = T::regs(); | 910 | let regs = T::regs(); |
| 905 | let stat = regs.epr(index).read().stat_tx(); | 911 | let stat = regs.epr(index).read().stat_tx(); |
| 906 | if self.info.ep_type == EndpointType::Isochronous { | 912 | if self.info.ep_type == EndpointType::Isochronous { |
| 907 | // The isochronous endpoint does not change its `STAT_RX` field to `NAK` when receiving a packet. | 913 | // The isochronous endpoint does not change its `STAT_TX` field to `NAK` after sending a packet. |
| 908 | // Therefore, this instead waits until the `CTR` interrupt was triggered. | 914 | // Therefore, this instead waits until the `CTR` interrupt was triggered. |
| 909 | if matches!(stat, Stat::DISABLED) || CTR_TRIGGERED[index].load(Ordering::Relaxed) { | 915 | if matches!(stat, Stat::DISABLED) || CTR_TX_TRIGGERED[index].load(Ordering::Relaxed) { |
| 916 | assert!(matches!(stat, Stat::VALID | Stat::DISABLED)); | ||
| 910 | Poll::Ready(stat) | 917 | Poll::Ready(stat) |
| 911 | } else { | 918 | } else { |
| 912 | Poll::Pending | 919 | Poll::Pending |
| @@ -921,7 +928,7 @@ impl<'d, T: Instance> driver::EndpointIn for Endpoint<'d, T, In> { | |||
| 921 | }) | 928 | }) |
| 922 | .await; | 929 | .await; |
| 923 | 930 | ||
| 924 | CTR_TRIGGERED[index].store(false, Ordering::Relaxed); | 931 | CTR_TX_TRIGGERED[index].store(false, Ordering::Relaxed); |
| 925 | 932 | ||
| 926 | if stat == Stat::DISABLED { | 933 | if stat == Stat::DISABLED { |
| 927 | return Err(EndpointError::Disabled); | 934 | return Err(EndpointError::Disabled); |
| @@ -942,7 +949,6 @@ impl<'d, T: Instance> driver::EndpointIn for Endpoint<'d, T, In> { | |||
| 942 | 949 | ||
| 943 | self.write_data_double_buffered(buf, packet_buffer); | 950 | self.write_data_double_buffered(buf, packet_buffer); |
| 944 | 951 | ||
| 945 | let regs = T::regs(); | ||
| 946 | regs.epr(index).write(|w| { | 952 | regs.epr(index).write(|w| { |
| 947 | w.set_ep_type(convert_type(self.info.ep_type)); | 953 | w.set_ep_type(convert_type(self.info.ep_type)); |
| 948 | w.set_ea(self.info.addr.index() as _); | 954 | w.set_ea(self.info.addr.index() as _); |
| @@ -955,7 +961,6 @@ impl<'d, T: Instance> driver::EndpointIn for Endpoint<'d, T, In> { | |||
| 955 | w.set_ctr_rx(true); // don't clear | 961 | w.set_ctr_rx(true); // don't clear |
| 956 | w.set_ctr_tx(true); // don't clear | 962 | w.set_ctr_tx(true); // don't clear |
| 957 | }); | 963 | }); |
| 958 | |||
| 959 | trace!("WRITE OK"); | 964 | trace!("WRITE OK"); |
| 960 | 965 | ||
| 961 | Ok(()) | 966 | Ok(()) |
