From 521218ce06fbb7bd518eb6a069406936079e3ec2 Mon Sep 17 00:00:00 2001 From: diogo464 Date: Sat, 11 Oct 2025 11:34:59 +0100 Subject: initial working version --- src/dhcp.rs | 398 +++++++++++++++++++++++++++++++++++++++++++++++++++++----- src/main.rs | 408 +++++++++++++++++++++++++++++++++++++++--------------------- src/tftp.rs | 377 +++++++++++++++++++++++++++++++++++++++++-------------- 3 files changed, 919 insertions(+), 264 deletions(-) (limited to 'src') diff --git a/src/dhcp.rs b/src/dhcp.rs index 38cc8e4..51680d1 100644 --- a/src/dhcp.rs +++ b/src/dhcp.rs @@ -1,6 +1,7 @@ use std::{ io::{Cursor, Read as _, Result, Write}, net::Ipv4Addr, + str::FromStr, }; use crate::wire; @@ -8,6 +9,11 @@ use crate::wire; const MAGIC_COOKIE: [u8; 4] = [0x63, 0x82, 0x53, 0x63]; const FLAG_BROADCAST: u16 = 1 << 15; +pub const VENDOR_CLASS_PXE_CLIENT: &'static [u8] = b"PXEClient"; +pub const VENDOR_CLASS_PXE_SERVER: &'static [u8] = b"PXEServer"; + +pub const USER_CLASS_IPXE: &'static [u8] = b"iPXE"; + #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)] pub enum BootOp { #[default] @@ -29,6 +35,8 @@ impl From for u8 { } } +pub type HardwareAddress = [u8; 16]; + #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)] pub enum HardwareType { #[default] @@ -78,10 +86,11 @@ pub enum DhcpOption { End, MessageType(DhcpMessageType), ServerIdentifier(Ipv4Addr), - VendorClassIdentifier(String), + VendorClassIdentifier(Vec), TftpServerName(String), TftpFileName(String), - UserClassInformation(String), + UserClassInformation(Vec), + ClientSystemArchitecture(SystemArchitecture), ClientMachineIdentifier(Vec), Unknown { code: u8, data: Vec }, } @@ -95,6 +104,7 @@ impl DhcpOption { pub const CODE_TFTP_SERVER_NAME: u8 = 66; pub const CODE_TFTP_FILE_NAME: u8 = 67; pub const CODE_USER_CLASS_INFORMATION: u8 = 77; + pub const CODE_CLIENT_SYSTEM_ARCHITECTURE: u8 = 93; pub const CODE_CLIENT_MACHINE_IDENTIFIER: u8 = 97; pub fn code(&self) -> u8 { @@ -107,12 +117,294 @@ impl DhcpOption { DhcpOption::TftpServerName(_) => Self::CODE_TFTP_SERVER_NAME, DhcpOption::TftpFileName(_) => Self::CODE_TFTP_FILE_NAME, DhcpOption::UserClassInformation(_) => Self::CODE_USER_CLASS_INFORMATION, + DhcpOption::ClientSystemArchitecture(_) => Self::CODE_CLIENT_SYSTEM_ARCHITECTURE, DhcpOption::ClientMachineIdentifier(_) => Self::CODE_CLIENT_MACHINE_IDENTIFIER, DhcpOption::Unknown { code, .. } => *code, } } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum SystemArchitecture { + IntelX86pc, + NECPC98, + EfiItanium, + DecAlpha, + ArcX86, + IntelLeanClient, + EfiIA32, + EfiBC, + EfiXscale, + EfiX86_64, + EfiARM32, + EfiARM64, + EfiARM32Http, + EfiARM64Http, + ARM32Uboot, + ARM64Uboot, + Unknown(u16), +} + +impl SystemArchitecture { + pub const CODE_INTEL_X86_PC: u16 = 0; + pub const CODE_NEC_PC98: u16 = 1; + pub const CODE_EFI_ITANIUM: u16 = 2; + pub const CODE_DEC_ALPHA: u16 = 3; + pub const CODE_ARC_X86: u16 = 4; + pub const CODE_INTEL_LEAN_CLIENT: u16 = 5; + pub const CODE_EFI_IA32: u16 = 6; + pub const CODE_EFI_BC: u16 = 7; + pub const CODE_EFI_XSCALE: u16 = 8; + pub const CODE_EFI_X86_64: u16 = 9; + pub const CODE_EFI_ARM32: u16 = 10; + pub const CODE_EFI_ARM64: u16 = 11; + pub const CODE_EFI_ARM32_HTTP: u16 = 18; + pub const CODE_EFI_ARM64_HTTP: u16 = 19; + pub const CODE_ARM32_UBOOT: u16 = 21; + pub const CODE_ARM64_UBOOT: u16 = 22; +} + +impl From for SystemArchitecture { + fn from(value: u16) -> Self { + match value { + Self::CODE_INTEL_X86_PC => SystemArchitecture::IntelX86pc, + Self::CODE_NEC_PC98 => SystemArchitecture::NECPC98, + Self::CODE_EFI_ITANIUM => SystemArchitecture::EfiItanium, + Self::CODE_DEC_ALPHA => SystemArchitecture::DecAlpha, + Self::CODE_ARC_X86 => SystemArchitecture::ArcX86, + Self::CODE_INTEL_LEAN_CLIENT => SystemArchitecture::IntelLeanClient, + Self::CODE_EFI_IA32 => SystemArchitecture::EfiIA32, + Self::CODE_EFI_BC => SystemArchitecture::EfiBC, + Self::CODE_EFI_XSCALE => SystemArchitecture::EfiXscale, + Self::CODE_EFI_X86_64 => SystemArchitecture::EfiX86_64, + Self::CODE_EFI_ARM32 => SystemArchitecture::EfiARM32, + Self::CODE_EFI_ARM64 => SystemArchitecture::EfiARM64, + Self::CODE_EFI_ARM32_HTTP => SystemArchitecture::EfiARM32Http, + Self::CODE_EFI_ARM64_HTTP => SystemArchitecture::EfiARM64Http, + Self::CODE_ARM32_UBOOT => SystemArchitecture::ARM32Uboot, + Self::CODE_ARM64_UBOOT => SystemArchitecture::ARM64Uboot, + _ => SystemArchitecture::Unknown(value), + } + } +} + +impl From for u16 { + fn from(value: SystemArchitecture) -> Self { + match value { + SystemArchitecture::IntelX86pc => SystemArchitecture::CODE_INTEL_X86_PC, + SystemArchitecture::NECPC98 => SystemArchitecture::CODE_NEC_PC98, + SystemArchitecture::EfiItanium => SystemArchitecture::CODE_EFI_ITANIUM, + SystemArchitecture::DecAlpha => SystemArchitecture::CODE_DEC_ALPHA, + SystemArchitecture::ArcX86 => SystemArchitecture::CODE_ARC_X86, + SystemArchitecture::IntelLeanClient => SystemArchitecture::CODE_INTEL_LEAN_CLIENT, + SystemArchitecture::EfiIA32 => SystemArchitecture::CODE_EFI_IA32, + SystemArchitecture::EfiBC => SystemArchitecture::CODE_EFI_BC, + SystemArchitecture::EfiXscale => SystemArchitecture::CODE_EFI_XSCALE, + SystemArchitecture::EfiX86_64 => SystemArchitecture::CODE_EFI_X86_64, + SystemArchitecture::EfiARM32 => SystemArchitecture::CODE_EFI_ARM32, + SystemArchitecture::EfiARM64 => SystemArchitecture::CODE_EFI_ARM64, + SystemArchitecture::EfiARM32Http => SystemArchitecture::CODE_EFI_ARM32_HTTP, + SystemArchitecture::EfiARM64Http => SystemArchitecture::CODE_EFI_ARM64_HTTP, + SystemArchitecture::ARM32Uboot => SystemArchitecture::CODE_ARM32_UBOOT, + SystemArchitecture::ARM64Uboot => SystemArchitecture::CODE_ARM64_UBOOT, + SystemArchitecture::Unknown(code) => code, + } + } +} + +impl FromStr for SystemArchitecture { + type Err = ::Err; + + fn from_str(s: &str) -> std::result::Result { + s.parse::().map(From::from) + } +} + +#[derive(Debug)] +pub struct InvalidPxeClassIdentifierKind(String); + +impl InvalidPxeClassIdentifierKind { + fn new(kind: impl Into) -> Self { + Self(kind.into()) + } +} + +impl std::fmt::Display for InvalidPxeClassIdentifierKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "invalid pxe class identifier kind '{}', expected 'PXEClient' or 'PXEServer'", + self.0 + ) + } +} + +impl std::error::Error for InvalidPxeClassIdentifierKind {} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum PxeClassIdentifierKind { + Client, + Server, +} + +impl PxeClassIdentifierKind { + pub const KIND_CLIENT: &'static str = "PXEClient"; + pub const KIND_SERVER: &'static str = "PXEServer"; +} + +impl FromStr for PxeClassIdentifierKind { + type Err = InvalidPxeClassIdentifierKind; + + fn from_str(s: &str) -> std::result::Result { + match s { + Self::KIND_CLIENT => Ok(Self::Client), + Self::KIND_SERVER => Ok(Self::Server), + _ => Err(InvalidPxeClassIdentifierKind::new(s)), + } + } +} + +#[derive(Debug)] +pub struct InvalidPxeClassIdentifier(String, String); + +impl InvalidPxeClassIdentifier { + fn new(class: impl Into, reason: impl Into) -> Self { + Self(class.into(), reason.into()) + } +} + +impl std::fmt::Display for InvalidPxeClassIdentifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "invalid pxe class identifier '{}': {}", self.0, self.1) + } +} + +impl std::error::Error for InvalidPxeClassIdentifier {} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum PxeClassIdentifier { + Client(PxeClassIdentifierClient), + Server(PxeClassIdentifierServer), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct PxeClassIdentifierServer; + +impl std::fmt::Display for PxeClassIdentifierServer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("PXEServer") + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct PxeClassIdentifierClient { + pub architecture: SystemArchitecture, + pub undi_major: u16, + pub undi_minor: u16, +} + +impl std::fmt::Display for PxeClassIdentifierClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "PXEClient:Arch:{:05}:UNDI:{:03}{:03}", + u16::from(self.architecture), + self.undi_major, + self.undi_minor + ) + } +} + +impl std::fmt::Display for PxeClassIdentifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PxeClassIdentifier::Client(client) => client.fmt(f), + PxeClassIdentifier::Server(server) => server.fmt(f), + } + } +} + +impl TryFrom<&[u8]> for PxeClassIdentifier { + type Error = InvalidPxeClassIdentifier; + + fn try_from(value: &[u8]) -> std::result::Result { + let str = std::str::from_utf8(value).map_err(|err| { + InvalidPxeClassIdentifier::new( + format!("{value:?}"), + format!("invalid utf-8 string: {err}"), + ) + })?; + str.parse() + } +} + +impl FromStr for PxeClassIdentifier { + type Err = InvalidPxeClassIdentifier; + + fn from_str(s: &str) -> std::result::Result { + let mut parts = s.split(":"); + let make_err = |reason: String| InvalidPxeClassIdentifier::new(s, reason); + + let kind = match parts.next() { + Some(kind) => kind + .parse::() + .map_err(|err| make_err(err.to_string()))?, + None => return Err(make_err("missing class kind".to_string())), + }; + + if kind == PxeClassIdentifierKind::Server { + if parts.next().is_some() { + return Err(make_err("invalid class".to_string())); + } + return Ok(Self::Server(PxeClassIdentifierServer)); + } + + if !parts.next().map(|s| s == "Arch").unwrap_or(false) { + return Err(make_err("invalid class".to_string())); + } + + let architecture = match parts.next() { + Some(arch) => arch + .parse::() + .map_err(|err| make_err(err.to_string()))?, + None => return Err(make_err("missing architecture".to_string())), + }; + + if !parts.next().map(|s| s == "UNDI").unwrap_or(false) { + return Err(make_err("invalid class".to_string())); + } + + let undi_str = match parts.next() { + Some(undi_str) => undi_str, + None => return Err(make_err("missing undi version".to_string())), + }; + + if undi_str.len() != 6 { + return Err(make_err("invalid undi version length".to_string())); + } + + let (undi_major_str, undi_minor_str) = undi_str.split_at_checked(3).unwrap(); + + let undi_major = undi_major_str + .parse::() + .map_err(|err| make_err(err.to_string()))?; + + let undi_minor = undi_minor_str + .parse::() + .map_err(|err| make_err(err.to_string()))?; + + if parts.next().is_some() { + return Err(make_err("invalid class".to_string())); + } + + Ok(Self::Client(PxeClassIdentifierClient { + architecture, + undi_major, + undi_minor, + })) + } +} + #[derive(Debug)] pub struct DhcpPacket { pub op: BootOp, @@ -125,7 +417,7 @@ pub struct DhcpPacket { pub yiaddr: Ipv4Addr, pub siaddr: Ipv4Addr, pub giaddr: Ipv4Addr, - pub chaddr: [u8; 16], + pub chaddr: HardwareAddress, // server host name pub sname: String, // boot file name @@ -158,11 +450,21 @@ impl DhcpPacket { pub fn new_boot( xid: u32, chaddr: [u8; 16], - client_uuid: Vec, + client_uuid: Option>, local_ip: Ipv4Addr, local_hostname: String, filename: String, ) -> Self { + let mut options = vec![ + DhcpOption::MessageType(DhcpMessageType::Offer), + DhcpOption::ServerIdentifier(local_ip), + DhcpOption::VendorClassIdentifier(b"PXEClient".to_vec()), + DhcpOption::TftpServerName(local_hostname), + DhcpOption::TftpFileName(filename), + ]; + if let Some(uuid) = client_uuid { + options.push(DhcpOption::ClientMachineIdentifier(uuid)); + } Self { op: BootOp::Reply, htype: HardwareType::Ethernet, @@ -177,25 +479,28 @@ impl DhcpPacket { chaddr, sname: Default::default(), file: Default::default(), - options: vec![ - DhcpOption::MessageType(DhcpMessageType::Offer), - DhcpOption::ServerIdentifier(local_ip), - DhcpOption::VendorClassIdentifier("PXEClient".to_string()), - DhcpOption::ClientMachineIdentifier(client_uuid), - DhcpOption::TftpServerName(local_hostname), - DhcpOption::TftpFileName(filename), - ], + options, } } pub fn new_boot_ack( xid: u32, chaddr: [u8; 16], - client_uuid: Vec, + client_uuid: Option>, local_ip: Ipv4Addr, hostname: String, filename: String, ) -> Self { + let mut options = vec![ + DhcpOption::MessageType(DhcpMessageType::Ack), + DhcpOption::ServerIdentifier(local_ip), + DhcpOption::VendorClassIdentifier(b"PXEClient".to_vec()), + DhcpOption::TftpServerName(hostname), + DhcpOption::TftpFileName(filename), + ]; + if let Some(uuid) = client_uuid { + options.push(DhcpOption::ClientMachineIdentifier(uuid)); + } Self { op: BootOp::Reply, htype: HardwareType::Ethernet, @@ -210,14 +515,7 @@ impl DhcpPacket { chaddr, sname: Default::default(), file: Default::default(), - options: vec![ - DhcpOption::MessageType(DhcpMessageType::Ack), - DhcpOption::ServerIdentifier(local_ip), - DhcpOption::VendorClassIdentifier("PXEClient".to_string()), - DhcpOption::ClientMachineIdentifier(client_uuid), - DhcpOption::TftpServerName(hostname), - DhcpOption::TftpFileName(filename), - ], + options, } } @@ -296,10 +594,23 @@ fn read_option(cursor: &mut Cursor<&[u8]>) -> Result { DhcpOption::CODE_PAD => DhcpOption::Pad, DhcpOption::CODE_END => DhcpOption::End, DhcpOption::CODE_VENDOR_CLASS_IDENTIFIER => { - DhcpOption::VendorClassIdentifier(read_len8_prefixed_string(cursor)?) + DhcpOption::VendorClassIdentifier(read_len8_prefixed_vec(cursor)?) } DhcpOption::CODE_USER_CLASS_INFORMATION => { - DhcpOption::UserClassInformation(read_len8_prefixed_string(cursor)?) + DhcpOption::UserClassInformation(read_len8_prefixed_vec(cursor)?) + } + DhcpOption::CODE_CLIENT_SYSTEM_ARCHITECTURE => { + let len = read_u8(cursor)?; + assert_eq!(len, 2); + + let mut buf = [0u8; 2]; + cursor.read_exact(&mut buf)?; + + let arch = SystemArchitecture::from(u16::from_be_bytes(buf)); + DhcpOption::ClientSystemArchitecture(arch) + } + DhcpOption::CODE_CLIENT_MACHINE_IDENTIFIER => { + DhcpOption::ClientMachineIdentifier(read_len8_prefixed_vec(cursor)?) } _ => { let len = read_u8(cursor)?; @@ -353,6 +664,20 @@ pub fn parse_packet(buf: &[u8]) -> Result { } pub fn write_packet(mut writer: W, packet: &DhcpPacket) -> Result<()> { + if packet.sname.len() >= 64 { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "sname cannot be longer than 64 bytes", + )); + } + + if packet.file.len() >= 128 { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "filename cannot be longer than 128 bytes", + )); + } + wire::write_u8(&mut writer, u8::from(packet.op))?; wire::write_u8(&mut writer, u8::from(packet.htype))?; wire::write_u8(&mut writer, packet.htype.hardware_len())?; @@ -365,10 +690,21 @@ pub fn write_packet(mut writer: W, packet: &DhcpPacket) -> Result<()> wire::write_ipv4(&mut writer, packet.siaddr)?; wire::write_ipv4(&mut writer, packet.giaddr)?; wire::write(&mut writer, &packet.chaddr)?; - //wire::write_null_terminated_string(&mut writer, &packet.sname)?; - //wire::write_null_terminated_string(&mut writer, &packet.file)?; - wire::write(&mut writer, &vec![0u8; 64])?; - wire::write(&mut writer, &vec![0u8; 128])?; + + let sname_bytes = packet.sname.as_bytes(); + wire::write(&mut writer, sname_bytes)?; + for _ in 0..(64 - sname_bytes.len()) { + wire::write_u8(&mut writer, 0)?; + } + + let file_bytes = packet.file.as_bytes(); + wire::write(&mut writer, file_bytes)?; + for _ in 0..(128 - file_bytes.len()) { + wire::write_u8(&mut writer, 0)?; + } + + // wire::write(&mut writer, &vec![0u8; 64])?; + // wire::write(&mut writer, &vec![0u8; 128])?; wire::write(&mut writer, &MAGIC_COOKIE)?; for option in &packet.options { write_option(&mut writer, option)?; @@ -390,12 +726,16 @@ pub fn write_option(mut writer: W, option: &DhcpOption) -> Result<()> wire::write_ipv4(&mut writer, *ip)?; } DhcpOption::VendorClassIdentifier(vendor_class) => { - write_option_len_prefixed_string(&mut writer, &vendor_class)? + write_option_len_prefixed_buf(&mut writer, &vendor_class)? } DhcpOption::TftpServerName(name) => write_option_len_prefixed_string(&mut writer, &name)?, DhcpOption::TftpFileName(name) => write_option_len_prefixed_string(&mut writer, &name)?, DhcpOption::UserClassInformation(user_class) => { - write_option_len_prefixed_string(&mut writer, &user_class)? + write_option_len_prefixed_buf(&mut writer, &user_class)? + } + DhcpOption::ClientSystemArchitecture(arch) => { + wire::write_u8(&mut writer, 2)?; + wire::write_u16(&mut writer, u16::from(*arch))?; } DhcpOption::ClientMachineIdentifier(identifier) => { write_option_len_prefixed_buf(&mut writer, &identifier)? diff --git a/src/main.rs b/src/main.rs index 51bbd77..c179ac0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +#![feature(gethostname)] #![feature(cursor_split)] pub mod dhcp; pub mod tftp; @@ -8,13 +9,273 @@ use std::{ net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket}, }; +use clap::Parser; use ipnet::Ipv4Net; use crate::dhcp::{DhcpOption, DhcpPacket}; -const LOCAL_IPV4: Ipv4Addr = Ipv4Addr::new(192, 168, 1, 103); -const LOCAL_HOSTNAME: &'static str = "Diogos-Air"; +const BOOT_FILE_X64_BIOS: &'static str = "netboot.xyz.kpxe"; +const BOOT_FILE_X64_EFI: &'static str = "netboot.xyz.efi"; +const BOOT_FILE_A64_EFI: &'static str = "netboot.xyz-arm64.efi"; +const MENU_FILE: &'static str = "menu.ipxe"; +#[derive(Debug, Parser)] +struct Cli { + #[clap(long)] + hostname: Option, + + #[clap(long, default_value = "0.0.0.0")] + listen_address: Ipv4Addr, + + #[clap(long, default_value = "67")] + dhcp_port: u16, + + #[clap(long, default_value = "4011")] + proxy_dhcp_port: u16, + + #[clap(long, default_value = "69")] + tftp_port: u16, +} + +struct Context { + local_hostname: String, + local_address: Ipv4Addr, +} + +fn main() { + let cli = Cli::parse(); + + let dhcp_sockaddr = SocketAddrV4::new(cli.listen_address, cli.dhcp_port); + let pdhcp_sockaddr = SocketAddrV4::new(cli.listen_address, cli.proxy_dhcp_port); + let tftp_sockaddr = SocketAddrV4::new(cli.listen_address, cli.tftp_port); + + let hostname = match cli.hostname { + Some(hostname) => hostname, + None => { + let hostname = std::net::hostname().expect("unable to obtain local machine's hostname"); + hostname + .into_string() + .expect("unable to convert local machine's hostname to utf-8 string") + } + }; + let local_ip_address = if cli.listen_address == Ipv4Addr::UNSPECIFIED { + let interfaces = list_network_interfaces().expect("unable to list network interfaces"); + let mut chosen = None; + for interface in interfaces { + if interface.address.is_loopback() { + continue; + } + chosen = Some((interface.interface, interface.address)); + break; + } + + let (name, addr) = + chosen.expect("unable to find network interface with non-loopback IPv4 address"); + println!("using local address {} from interface {}", addr, name); + addr + } else { + cli.listen_address + }; + + println!("local hostname = {hostname}"); + println!("local address = {local_ip_address}"); + + let context = Context { + local_hostname: hostname, + local_address: local_ip_address, + }; + + let socket_dhcp = UdpSocket::bind(dhcp_sockaddr).unwrap(); + socket_dhcp.set_broadcast(true).unwrap(); + socket_dhcp.set_nonblocking(true).unwrap(); + + let socket_pdhcp = UdpSocket::bind(pdhcp_sockaddr).unwrap(); + socket_pdhcp.set_broadcast(true).unwrap(); + socket_pdhcp.set_nonblocking(true).unwrap(); + + let socket_tftp = UdpSocket::bind(tftp_sockaddr).unwrap(); + socket_tftp.set_broadcast(false).unwrap(); + socket_tftp.set_nonblocking(true).unwrap(); + + let tftp_filesystem = tftp::StaticFileSystem::new(&[ + (BOOT_FILE_X64_BIOS, include_bytes!("../netboot.xyz.kpxe")), + (BOOT_FILE_X64_EFI, include_bytes!("../netboot.xyz.efi")), + ( + BOOT_FILE_A64_EFI, + include_bytes!("../netboot.xyz-arm64.efi"), + ), + (MENU_FILE, include_bytes!("../menu.ipxe")), + ]); + let mut tftp_server = tftp::Server::default(); + + loop { + let mut buf = [0u8; 1500]; + + if let Ok((n, addr)) = socket_dhcp.recv_from(&mut buf) { + println!("Received {} bytes from {} on port 67", n, addr); + handle_packet(&context, &buf[..n], &socket_dhcp); + } + + if let Ok((n, addr)) = socket_pdhcp.recv_from(&mut buf) { + println!("Received {} bytes from {} on port 4011", n, addr); + handle_packet_4011(&context, &buf[..n], &socket_pdhcp, addr); + } + + if let Ok((n, addr)) = socket_tftp.recv_from(&mut buf) { + println!("Received {} bytes from {} on port 4011", n, addr); + match tftp_server.process(&tftp_filesystem, addr, &buf) { + tftp::ServerCommand::Send(tftp_packet) => { + let mut output = Vec::default(); + tftp_packet.write(&mut output).unwrap(); + socket_tftp.send_to(&output, addr).unwrap(); + } + tftp::ServerCommand::Ignore => {} + } + } + + std::thread::sleep(std::time::Duration::from_millis(1)); + } +} + +fn handle_packet(context: &Context, buf: &[u8], socket: &UdpSocket) { + let packet = match dhcp::parse_packet(buf) { + Ok(packet) => packet, + Err(err) => { + eprintln!("failed to parse DHCP packet: {err}"); + return; + } + }; + + println!("Parsed DHCP packet: XID={:08x}", packet.xid); + + // Check if it's a PXE client and extract client UUID + let mut pxe_class = None; + let mut client_uuid = None; + let mut is_ipxe = false; + + for option in &packet.options { + match option { + DhcpOption::VendorClassIdentifier(vendor_class) => { + if let Ok(class) = dhcp::PxeClassIdentifier::try_from(vendor_class.as_slice()) { + println!("{class}"); + pxe_class = Some(class); + } + } + DhcpOption::UserClassInformation(user_class) => { + if user_class == dhcp::USER_CLASS_IPXE { + is_ipxe = true; + } + } + DhcpOption::ClientMachineIdentifier(uuid) => { + client_uuid = Some(uuid.clone()); + } + _ => {} + } + } + + let pxe_client_class = match pxe_class { + Some(dhcp::PxeClassIdentifier::Client(class)) => class, + _ => { + println!("Not a PXE client, ignoring"); + return; + } + }; + + println!("Responding to PXE client with DHCPOFFER"); + let mut response_buf = Vec::default(); + let response = DhcpPacket::new_boot( + packet.xid, + packet.chaddr, + client_uuid, + context.local_address, + context.local_hostname.clone(), + match is_ipxe { + true => MENU_FILE.to_string(), + false => match pxe_client_class.architecture { + dhcp::SystemArchitecture::IntelX86pc => BOOT_FILE_X64_BIOS.to_string(), + dhcp::SystemArchitecture::EfiARM64 => BOOT_FILE_A64_EFI.to_string(), + dhcp::SystemArchitecture::EfiX86_64 | dhcp::SystemArchitecture::EfiBC => { + BOOT_FILE_X64_EFI.to_string() + } + _ => { + eprintln!( + "unsupported architecture {:?}", + pxe_client_class.architecture + ); + return; + } + }, + }, + ); + response.write(&mut response_buf).unwrap(); + socket + .send_to(&response_buf, SocketAddrV4::new(Ipv4Addr::BROADCAST, 68)) + .unwrap(); +} + +fn handle_packet_4011(context: &Context, buf: &[u8], socket: &UdpSocket, sender_addr: SocketAddr) { + let packet = match dhcp::parse_packet(buf) { + Ok(packet) => packet, + Err(err) => { + println!("Failed to parse packet on 4011: {}", err); + return; + } + }; + + println!("Parsed DHCP packet on 4011: XID={:08x}", packet.xid); + + // Extract client UUID + let mut client_uuid = None; + for option in &packet.options { + if let DhcpOption::ClientMachineIdentifier(uuid) = option { + client_uuid = Some(uuid.clone()); + break; + } + } + + let mut client_class = None; + for option in &packet.options { + if let DhcpOption::VendorClassIdentifier(vendor_class) = option { + if let Ok(dhcp::PxeClassIdentifier::Client(class)) = + dhcp::PxeClassIdentifier::try_from(vendor_class.as_slice()) + { + println!("{class}"); + client_class = Some(class); + } + } + } + let client_class = match client_class { + Some(class) => class, + None => return, + }; + + let file = match client_class.architecture { + dhcp::SystemArchitecture::IntelX86pc => BOOT_FILE_X64_BIOS.to_string(), + dhcp::SystemArchitecture::EfiARM64 => BOOT_FILE_A64_EFI.to_string(), + dhcp::SystemArchitecture::EfiX86_64 | dhcp::SystemArchitecture::EfiBC => { + BOOT_FILE_X64_EFI.to_string() + } + _ => { + eprintln!("unsupported architecture {:?}", client_class.architecture); + return; + } + }; + + println!("Responding with DHCPACK"); + let mut response_buf = Vec::default(); + let response = DhcpPacket::new_boot_ack( + packet.xid, + packet.chaddr, + client_uuid, + context.local_address, + context.local_hostname.clone(), + file, + ); + response.write(&mut response_buf).unwrap(); + socket.send_to(&response_buf, sender_addr).unwrap(); +} + +#[allow(unused)] #[derive(Debug, Clone)] struct InterfaceAddr { interface: String, @@ -68,146 +329,3 @@ fn list_network_interfaces() -> Result> { Ok(interfaces) } } - -fn main() { - let socket67 = UdpSocket::bind("0.0.0.0:67").unwrap(); - socket67.set_broadcast(true).unwrap(); - socket67.set_nonblocking(true).unwrap(); - - let socket4011 = UdpSocket::bind("0.0.0.0:4011").unwrap(); - socket4011.set_broadcast(true).unwrap(); - socket4011.set_nonblocking(true).unwrap(); - - std::thread::spawn(|| { - tftp::serve("tftp").unwrap(); - }); - - loop { - let mut buf = [0u8; 1500]; - - // Try port 67 first - if let Ok((n, addr)) = socket67.recv_from(&mut buf) { - println!("Received {} bytes from {} on port 67", n, addr); - handle_packet(&buf[..n], &socket67); - } else if let Ok((n, addr)) = socket4011.recv_from(&mut buf) { - println!("Received {} bytes from {} on port 4011", n, addr); - handle_packet_4011(&buf[..n], &socket4011, addr); - } else { - std::thread::sleep(std::time::Duration::from_millis(10)); - } - } -} - -fn handle_packet(buf: &[u8], socket: &UdpSocket) { - match dhcp::parse_packet(buf) { - Ok(packet) => { - println!("Parsed DHCP packet: XID={:08x}", packet.xid); - - // Check if it's a PXE client and extract client UUID - let mut is_pxe = false; - let mut client_uuid = None; - let mut is_ipxe = false; - - for option in &packet.options { - match option { - DhcpOption::VendorClassIdentifier(vendor_class) => { - println!("Vendor class: {}", vendor_class); - if vendor_class.contains("PXEClient") { - is_pxe = true; - } - } - DhcpOption::UserClassInformation(user_class) => { - println!("User class: {}", user_class); - is_ipxe = true; - } - DhcpOption::Unknown { code: 97, data } => { - println!("Found client machine identifier"); - client_uuid = Some(data.clone()); - } - _ => {} - } - } - - if is_pxe { - println!("Responding to PXE client with DHCPOFFER"); - let mut response_buf = Vec::default(); - let response = DhcpPacket::new_boot( - packet.xid, - packet.chaddr, - client_uuid.unwrap(), - LOCAL_IPV4, - LOCAL_HOSTNAME.to_string(), - match is_ipxe { - true => "test.ipxe".to_string(), - false => "ipxe.efi".to_string(), - }, - ); - response.write(&mut response_buf).unwrap(); - socket - .send_to(&response_buf, SocketAddrV4::new(Ipv4Addr::BROADCAST, 68)) - .unwrap(); - } else { - println!("Not a PXE client, ignoring"); - } - } - Err(e) => { - println!("Failed to parse packet: {}", e); - } - } -} - -fn handle_packet_4011(buf: &[u8], socket: &UdpSocket, sender_addr: SocketAddr) { - match dhcp::parse_packet(buf) { - Ok(packet) => { - println!("Parsed DHCP packet on 4011: XID={:08x}", packet.xid); - - // Extract client UUID - let mut client_uuid = None; - for option in &packet.options { - if let DhcpOption::Unknown { code: 97, data } = option { - client_uuid = Some(data.clone()); - break; - } - } - - println!("Responding with DHCPACK"); - let mut response_buf = Vec::default(); - let response = DhcpPacket::new_boot_ack( - packet.xid, - packet.chaddr, - client_uuid.unwrap(), - LOCAL_IPV4, - LOCAL_HOSTNAME.to_string(), - "ipxe.efi".to_string(), - ); - response.write(&mut response_buf).unwrap(); - socket.send_to(&response_buf, sender_addr).unwrap(); - } - Err(e) => { - println!("Failed to parse packet on 4011: {}", e); - } - } -} - -const DHCP_PACKET_PAYLOAD: &'static [u8] = &[ - 0x1, 0x1, 0x6, 0x0, 0xf1, 0x25, 0x7c, 0x21, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2b, 0x67, 0x3f, 0xda, 0x70, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x63, 0x82, 0x53, 0x63, 0x35, 0x1, 0x1, 0x39, - 0x2, 0x5, 0xc0, 0x37, 0x23, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0xc, 0xd, 0xf, 0x11, 0x12, 0x16, - 0x17, 0x1c, 0x28, 0x29, 0x2a, 0x2b, 0x32, 0x33, 0x36, 0x3a, 0x3b, 0x3c, 0x42, 0x43, 0x61, 0x80, - 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x61, 0x11, 0x0, 0xcc, 0xfc, 0x32, 0x1b, 0xce, 0x2a, - 0xb2, 0x11, 0xa8, 0x5c, 0xb1, 0xac, 0x38, 0x38, 0x10, 0xf, 0x5e, 0x3, 0x1, 0x3, 0x10, 0x5d, - 0x2, 0x0, 0x7, 0x3c, 0x20, 0x50, 0x58, 0x45, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x3a, 0x41, - 0x72, 0x63, 0x68, 0x3a, 0x30, 0x30, 0x30, 0x30, 0x37, 0x3a, 0x55, 0x4e, 0x44, 0x49, 0x3a, 0x30, - 0x30, 0x33, 0x30, 0x31, 0x36, 0xff, -]; diff --git a/src/tftp.rs b/src/tftp.rs index d986a44..72bac22 100644 --- a/src/tftp.rs +++ b/src/tftp.rs @@ -1,7 +1,7 @@ use std::{ + collections::HashMap, io::{Cursor, Read as _, Result, Write}, - net::UdpSocket, - path::{Path, PathBuf}, + net::SocketAddr, str::FromStr, }; @@ -9,6 +9,9 @@ use crate::wire; pub const PORT: u16 = 69; +const DEFAULT_BLOCK_SIZE: u64 = 512; +const MAX_BLOCK_SIZE: usize = 2048; + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct InvalidTftpOp(u16); @@ -91,7 +94,7 @@ impl FromStr for TftpMode { #[derive(Debug)] pub enum TftpPacket { - Request(TftpRequestPacket), + ReadRequest(TftpReadRequestPacket), Data(TftpDataPacket), Ack(TftpAckPacket), OAck(TftpOAckPacket), @@ -99,9 +102,13 @@ pub enum TftpPacket { } impl TftpPacket { + pub fn parse(buf: &[u8]) -> Result { + parse_packet(buf) + } + pub fn write(&self, writer: W) -> Result<()> { match self { - TftpPacket::Request(p) => p.write(writer), + TftpPacket::ReadRequest(p) => p.write(writer), TftpPacket::Data(p) => p.write(writer), TftpPacket::Ack(p) => p.write(writer), TftpPacket::OAck(p) => p.write(writer), @@ -111,14 +118,14 @@ impl TftpPacket { } #[derive(Debug)] -pub struct TftpRequestPacket { +pub struct TftpReadRequestPacket { pub filename: String, pub mode: TftpMode, pub tsize: Option, pub blksize: Option, } -impl TftpRequestPacket { +impl TftpReadRequestPacket { pub fn write(&self, mut writer: W) -> Result<()> { todo!() } @@ -183,16 +190,85 @@ impl TftpOAckPacket { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum TftpErrorCode { + Undefined, + FileNotFound, + AccessViolation, + DiskFull, + IllegalOperation, + UnknownTransferId, + FileAreadyExists, + NoSuchUser, + Unknown(u16), +} + +impl std::fmt::Display for TftpErrorCode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?} ({})", *self, u16::from(*self)) + } +} + +impl TftpErrorCode { + pub const CODE_UNDEFINED: u16 = 0; + pub const CODE_FILE_NOT_FOUND: u16 = 1; + pub const CODE_ACCESS_VIOLATION: u16 = 2; + pub const CODE_DISK_FULL: u16 = 3; + pub const CODE_ILLEGAL_OPERATION: u16 = 4; + pub const CODE_UNKNOWN_TRANSFER_ID: u16 = 5; + pub const CODE_FILE_ALREADY_EXISTS: u16 = 6; + pub const CODE_NO_SUCH_USER: u16 = 7; +} + +impl From for TftpErrorCode { + fn from(value: u16) -> Self { + match value { + Self::CODE_UNDEFINED => Self::Undefined, + Self::CODE_FILE_NOT_FOUND => Self::FileNotFound, + Self::CODE_ACCESS_VIOLATION => Self::AccessViolation, + Self::CODE_DISK_FULL => Self::DiskFull, + Self::CODE_ILLEGAL_OPERATION => Self::IllegalOperation, + Self::CODE_UNKNOWN_TRANSFER_ID => Self::UnknownTransferId, + Self::CODE_FILE_ALREADY_EXISTS => Self::FileAreadyExists, + Self::CODE_NO_SUCH_USER => Self::NoSuchUser, + unknown => Self::Unknown(unknown), + } + } +} + +impl From for u16 { + fn from(value: TftpErrorCode) -> Self { + match value { + TftpErrorCode::Undefined => TftpErrorCode::CODE_UNDEFINED, + TftpErrorCode::FileNotFound => TftpErrorCode::CODE_FILE_NOT_FOUND, + TftpErrorCode::AccessViolation => TftpErrorCode::CODE_ACCESS_VIOLATION, + TftpErrorCode::DiskFull => TftpErrorCode::CODE_DISK_FULL, + TftpErrorCode::IllegalOperation => TftpErrorCode::CODE_ILLEGAL_OPERATION, + TftpErrorCode::UnknownTransferId => TftpErrorCode::CODE_UNKNOWN_TRANSFER_ID, + TftpErrorCode::FileAreadyExists => TftpErrorCode::CODE_FILE_ALREADY_EXISTS, + TftpErrorCode::NoSuchUser => TftpErrorCode::CODE_NO_SUCH_USER, + TftpErrorCode::Unknown(code) => code, + } + } +} + #[derive(Debug)] pub struct TftpErrorPacket { - pub code: u16, + pub code: TftpErrorCode, pub message: String, } impl TftpErrorPacket { + pub fn new(code: TftpErrorCode, message: impl Into) -> Self { + Self { + code, + message: message.into(), + } + } + pub fn write(&self, mut writer: W) -> Result<()> { wire::write_u16(&mut writer, TftpOp::Error.into())?; - wire::write_u16(&mut writer, self.code)?; + wire::write_u16(&mut writer, u16::from(self.code))?; wire::write_null_terminated_string(&mut writer, &self.message)?; Ok(()) } @@ -224,7 +300,7 @@ pub fn parse_packet(buf: &[u8]) -> Result { } } - Ok(TftpPacket::Request(TftpRequestPacket { + Ok(TftpPacket::ReadRequest(TftpReadRequestPacket { filename, mode, tsize, @@ -243,7 +319,7 @@ pub fn parse_packet(buf: &[u8]) -> Result { Ok(TftpPacket::Ack(TftpAckPacket { block })) } TftpOp::Error => { - let code = wire::read_u16(&mut cursor)?; + let code = TftpErrorCode::from(wire::read_u16(&mut cursor)?); let message = wire::read_null_terminated_string(&mut cursor)?; Ok(TftpPacket::Error(TftpErrorPacket { code, message })) } @@ -267,101 +343,222 @@ pub fn parse_packet(buf: &[u8]) -> Result { } } -pub fn serve(dir: impl AsRef) -> Result<()> { - let dir = dir.as_ref(); - let socket = UdpSocket::bind(format!("0.0.0.0:{PORT}"))?; +pub trait FileSystem { + fn stat(&self, filename: &str) -> Result; + fn read(&self, filename: &str, offset: u64, buf: &mut [u8]) -> Result; +} - // TODO: this needs to be done per addr - let mut last_blksize = 512u64; - let mut current_file = PathBuf::default(); +#[derive(Debug)] +pub struct StaticFileSystem { + files: &'static [(&'static str, &'static [u8])], +} + +impl StaticFileSystem { + pub fn new(files: &'static [(&'static str, &'static [u8])]) -> Self { + Self { files } + } - loop { - let mut buf = [0u8; 1500]; - let (n, addr) = socket.recv_from(&mut buf)?; - let packet = parse_packet(&buf[..n]).unwrap(); + fn find_file(&self, filename: &str) -> Result<&'static [u8]> { + self.files + .iter() + .find(|(name, _)| *name == filename) + .map(|(_, contents)| *contents) + .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "file not found")) + } +} + +impl FileSystem for StaticFileSystem { + fn stat(&self, filename: &str) -> Result { + let file = self.find_file(filename)?; + Ok(u64::try_from(file.len()).unwrap()) + } + + fn read(&self, filename: &str, offset: u64, buf: &mut [u8]) -> Result { + let file = self.find_file(filename)?; + let offset = usize::try_from(offset).unwrap(); + if offset >= file.len() { + return Ok(0); + } + + let rem = &file[offset..]; + let copy_n = rem.len().min(buf.len()); + buf[..copy_n].copy_from_slice(&rem[..copy_n]); + Ok(u64::try_from(copy_n).unwrap()) + } +} - let response = match packet { - TftpPacket::Request(req) => { +#[derive(Debug)] +struct Client { + blksize: u64, + filename: Option, +} + +impl Default for Client { + fn default() -> Self { + Self { + blksize: DEFAULT_BLOCK_SIZE, + filename: None, + } + } +} + +pub enum ServerCommand { + Send(TftpPacket), + Ignore, +} + +impl ServerCommand { + fn error(code: TftpErrorCode, message: impl Into) -> Self { + Self::Send(TftpPacket::Error(TftpErrorPacket::new(code, message))) + } +} + +#[derive(Debug, Default)] +pub struct Server { + clients: HashMap, +} + +impl Server { + pub fn process( + &mut self, + fs: &dyn FileSystem, + source: SocketAddr, + buf: &[u8], + ) -> ServerCommand { + let packet = match TftpPacket::parse(buf) { + Ok(packet) => packet, + Err(err) => { + return ServerCommand::Send(TftpPacket::Error(TftpErrorPacket::new( + TftpErrorCode::Undefined, + format!("invalid packet: {err}"), + ))); + } + }; + + match packet { + TftpPacket::ReadRequest(req) => self.process_read_req(fs, source, &req), + TftpPacket::Ack(ack) => self.process_ack(fs, source, &ack), + TftpPacket::Error(err) => { println!( - "Request options: tsize={:?}, blksize={:?}", - req.tsize, req.blksize + "received error from client {}: ({}) {}", + source, err.code, err.message ); + self.clients.remove(&source); + ServerCommand::Ignore + } + TftpPacket::Data(_) | TftpPacket::OAck(_) => ServerCommand::Ignore, + } + } - let filepath = dir.join(req.filename); - current_file = filepath.clone(); - let meta = std::fs::metadata(&filepath).unwrap(); - let actual_file_size = meta.len(); + fn process_read_req( + &mut self, + fs: &dyn FileSystem, + source: SocketAddr, + req: &TftpReadRequestPacket, + ) -> ServerCommand { + println!( + "Request options: tsize={:?}, blksize={:?}", + req.tsize, req.blksize + ); + + let client = self.clients.entry(source).or_default(); + client.filename = Some(req.filename.clone()); + + // Only send OACK if client sent options + if req.tsize.is_some() || req.blksize.is_some() { + if let Some(blksize) = req.blksize { + client.blksize = blksize; + } - // Only send OACK if client sent options - if req.tsize.is_some() || req.blksize.is_some() { - if let Some(blksize) = req.blksize { - last_blksize = blksize; + let tsize_response = if req.tsize.is_some() { + let filesize = match fs.stat(&req.filename) { + Ok(filesize) => filesize, + Err(err) => { + return ServerCommand::error( + TftpErrorCode::Undefined, + format!("failed to obtain file size: {}", err), + ); } + }; - let tsize_response = if req.tsize.is_some() { - Some(actual_file_size) - } else { - None - }; - - Some(TftpPacket::OAck(TftpOAckPacket { - tsize: tsize_response, - blksize: req.blksize, - })) - } else { - // No options, send first data block directly - let contents = std::fs::read(&filepath).unwrap(); - let block_size = 512; - let first_block = if contents.len() > block_size { - contents[..block_size].to_vec() - } else { - contents - }; - - Some(TftpPacket::Data(TftpDataPacket { - block: 1, - data: first_block, - })) - } - } - TftpPacket::Data(dat) => unimplemented!(), - TftpPacket::Ack(ack) => { - println!("Received ACK packet: block {}", ack.block); - - let contents = std::fs::read(¤t_file).unwrap(); - let next_block = ack.block + 1; - let start_offset = (next_block - 1) as u64 * last_blksize; - let end_offset = next_block as u64 * last_blksize; - let prev_start_offset = (next_block.saturating_sub(2)) as u64 * last_blksize; - let prev_remain = contents.len() - prev_start_offset as usize; - if prev_remain as u64 >= last_blksize || ack.block == 0 { - let end = std::cmp::min(end_offset as usize, contents.len()); - let block_data = contents[start_offset as usize..end].to_vec(); - println!("sending tftp data packet with {} bytes", block_data.len()); - Some(TftpPacket::Data(TftpDataPacket { - block: next_block, - data: block_data, - })) - } else { - None + Some(filesize) + } else { + None + }; + + ServerCommand::Send(TftpPacket::OAck(TftpOAckPacket { + tsize: tsize_response, + blksize: req.blksize, + })) + } else { + // No options, send first data block directly + let options = self.clients.entry(source).or_default(); + let block_size = usize::try_from(options.blksize).unwrap(); + + assert!(block_size <= MAX_BLOCK_SIZE); + let mut contents = [0u8; MAX_BLOCK_SIZE]; + let contents = &mut contents[..block_size]; + + let n = match fs.read(&req.filename, 0, contents) { + Ok(n) => usize::try_from(n).unwrap(), + Err(err) => { + return ServerCommand::error( + TftpErrorCode::Undefined, + format!("failed to read file contents: {}", err), + ); } + }; + + ServerCommand::Send(TftpPacket::Data(TftpDataPacket { + block: 1, + data: contents[..n].to_vec(), + })) + } + } + + fn process_ack( + &mut self, + fs: &dyn FileSystem, + source: SocketAddr, + ack: &TftpAckPacket, + ) -> ServerCommand { + println!("Received ACK packet: block {}", ack.block); + + let client = self.clients.entry(source).or_default(); + let filename = match &client.filename { + Some(filename) => filename, + None => { + return ServerCommand::error( + TftpErrorCode::Undefined, + "unknown filename for client", + ); } - TftpPacket::OAck(ack) => todo!(), - TftpPacket::Error(err) => { - println!( - "Received ERROR packet: code {}, message: {}", - err.code, err.message + }; + let block_size = usize::try_from(client.blksize).unwrap(); + + let next_block = ack.block + 1; + let start_offset = (usize::from(next_block) - 1) * block_size; + + let mut contents = [0u8; MAX_BLOCK_SIZE]; + let contents = &mut contents[..block_size]; + let n = match fs.read(filename, u64::try_from(start_offset).unwrap(), contents) { + Ok(n) => usize::try_from(n).unwrap(), + Err(err) => { + return ServerCommand::error( + TftpErrorCode::Undefined, + format!("failed to read file contents: {}", err), ); - None } }; + let contents = &contents[..n]; - if let Some(response) = response { - let mut writer = Cursor::new(&mut buf[..]); - println!("Sending to {addr}: {response:#?}"); - response.write(&mut writer).unwrap(); - let (response, _) = writer.split(); - socket.send_to(&response, addr).unwrap(); + if contents.is_empty() { + return ServerCommand::Ignore; } + + ServerCommand::Send(TftpPacket::Data(TftpDataPacket { + block: next_block, + data: contents.to_vec(), + })) } } -- cgit