#![feature(gethostname)] #![feature(cursor_split)] pub mod dhcp; pub mod tftp; pub mod wire; use std::{ io::Result, net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket}, }; use clap::Parser; use ipnet::Ipv4Net; use crate::dhcp::{DhcpOption, DhcpPacket}; const BOOT_FILE_X64_BIOS: &str = "netboot.xyz.kpxe"; const BOOT_FILE_X64_EFI: &str = "netboot.xyz.efi"; const BOOT_FILE_A64_EFI: &str = "netboot.xyz-arm64.efi"; const MENU_FILE: &str = "menu.ipxe"; #[derive(Debug, Parser)] struct Cli { #[clap(long)] hostname: Option, #[clap(long)] interface: Option, #[clap(long, default_value = "0.0.0.0")] listen_address: Ipv4Addr, #[clap(long, default_value = "67")] dhcp_port: u16, #[clap(long, default_value = "4011")] proxy_dhcp_port: u16, #[clap(long, default_value = "69")] tftp_port: u16, } struct Context { local_hostname: String, local_address: Ipv4Addr, } fn main() { let cli = Cli::parse(); let dhcp_sockaddr = SocketAddrV4::new(cli.listen_address, cli.dhcp_port); let pdhcp_sockaddr = SocketAddrV4::new(cli.listen_address, cli.proxy_dhcp_port); let tftp_sockaddr = SocketAddrV4::new(cli.listen_address, cli.tftp_port); let hostname = match cli.hostname { Some(hostname) => hostname, None => { let hostname = std::net::hostname().expect("unable to obtain local machine's hostname"); hostname .into_string() .expect("unable to convert local machine's hostname to utf-8 string") } }; let local_ip_address = if let Some(interface_name) = cli.interface { let interfaces = list_network_interfaces().expect("unable to list network interfaces"); let interface = interfaces .iter() .find(|i| i.interface == interface_name) .expect("interface not found"); interface.address } else if cli.listen_address == Ipv4Addr::UNSPECIFIED { let interfaces = list_network_interfaces().expect("unable to list network interfaces"); let mut chosen = None; for interface in interfaces { if interface.address.is_loopback() { continue; } chosen = Some((interface.interface, interface.address)); break; } let (name, addr) = chosen.expect("unable to find network interface with non-loopback IPv4 address"); println!("using local address {} from interface {}", addr, name); addr } else { cli.listen_address }; println!("local hostname = {hostname}"); println!("local address = {local_ip_address}"); let context = Context { local_hostname: hostname, local_address: local_ip_address, }; let socket_dhcp = UdpSocket::bind(dhcp_sockaddr).unwrap(); socket_dhcp.set_broadcast(true).unwrap(); socket_dhcp.set_nonblocking(true).unwrap(); let socket_pdhcp = UdpSocket::bind(pdhcp_sockaddr).unwrap(); socket_pdhcp.set_broadcast(true).unwrap(); socket_pdhcp.set_nonblocking(true).unwrap(); let socket_tftp = UdpSocket::bind(tftp_sockaddr).unwrap(); socket_tftp.set_broadcast(false).unwrap(); socket_tftp.set_nonblocking(true).unwrap(); let tftp_filesystem = tftp::StaticFileSystem::new(&[ (BOOT_FILE_X64_BIOS, include_bytes!("../netboot.xyz.kpxe")), (BOOT_FILE_X64_EFI, include_bytes!("../netboot.xyz.efi")), ( BOOT_FILE_A64_EFI, include_bytes!("../netboot.xyz-arm64.efi"), ), (MENU_FILE, include_bytes!("../menu.ipxe")), ]); let mut tftp_server = tftp::Server::default(); loop { let mut buf = [0u8; 1500]; if let Ok((n, addr)) = socket_dhcp.recv_from(&mut buf) { println!("Received {} bytes from {} on port 67", n, addr); handle_packet(&context, &buf[..n], &socket_dhcp); } if let Ok((n, addr)) = socket_pdhcp.recv_from(&mut buf) { println!("Received {} bytes from {} on port 4011", n, addr); handle_packet_4011(&context, &buf[..n], &socket_pdhcp, addr); } if let Ok((n, addr)) = socket_tftp.recv_from(&mut buf) { println!("Received {} bytes from {} on port 4011", n, addr); match tftp_server.process(&tftp_filesystem, addr, &buf) { tftp::ServerCommand::Send(tftp_packet) => { let mut output = Vec::default(); tftp_packet.write(&mut output).unwrap(); socket_tftp.send_to(&output, addr).unwrap(); } tftp::ServerCommand::Ignore => {} } } std::thread::sleep(std::time::Duration::from_millis(1)); } } fn handle_packet(context: &Context, buf: &[u8], socket: &UdpSocket) { let packet = match dhcp::parse_packet(buf) { Ok(packet) => packet, Err(err) => { eprintln!("failed to parse DHCP packet: {err}"); return; } }; println!("Parsed DHCP packet: XID={:08x}", packet.xid); // Check if it's a PXE client and extract client UUID let mut pxe_class = None; let mut client_uuid = None; let mut is_ipxe = false; for option in &packet.options { match option { DhcpOption::VendorClassIdentifier(vendor_class) => { if let Ok(class) = dhcp::PxeClassIdentifier::try_from(vendor_class.as_slice()) { println!("{class}"); pxe_class = Some(class); } } DhcpOption::UserClassInformation(user_class) => { if user_class == dhcp::USER_CLASS_IPXE { is_ipxe = true; } } DhcpOption::ClientMachineIdentifier(uuid) => { client_uuid = Some(uuid.clone()); } _ => {} } } let pxe_client_class = match pxe_class { Some(dhcp::PxeClassIdentifier::Client(class)) => class, _ => { println!("Not a PXE client, ignoring"); return; } }; println!("Responding to PXE client with DHCPOFFER"); let mut response_buf = Vec::default(); let response = DhcpPacket::new_boot( packet.xid, packet.chaddr, client_uuid, context.local_address, context.local_hostname.clone(), match is_ipxe { true => MENU_FILE.to_string(), false => match pxe_client_class.architecture { dhcp::SystemArchitecture::IntelX86pc => BOOT_FILE_X64_BIOS.to_string(), dhcp::SystemArchitecture::EfiARM64 => BOOT_FILE_A64_EFI.to_string(), dhcp::SystemArchitecture::EfiX86_64 | dhcp::SystemArchitecture::EfiBC => { BOOT_FILE_X64_EFI.to_string() } _ => { eprintln!( "unsupported architecture {:?}", pxe_client_class.architecture ); return; } }, }, ); response.write(&mut response_buf).unwrap(); socket .send_to(&response_buf, SocketAddrV4::new(Ipv4Addr::BROADCAST, 68)) .unwrap(); } fn handle_packet_4011(context: &Context, buf: &[u8], socket: &UdpSocket, sender_addr: SocketAddr) { let packet = match dhcp::parse_packet(buf) { Ok(packet) => packet, Err(err) => { println!("Failed to parse packet on 4011: {}", err); return; } }; println!("Parsed DHCP packet on 4011: XID={:08x}", packet.xid); // Extract client UUID let mut client_uuid = None; for option in &packet.options { if let DhcpOption::ClientMachineIdentifier(uuid) = option { client_uuid = Some(uuid.clone()); break; } } let mut client_class = None; for option in &packet.options { if let DhcpOption::VendorClassIdentifier(vendor_class) = option && let Ok(dhcp::PxeClassIdentifier::Client(class)) = dhcp::PxeClassIdentifier::try_from(vendor_class.as_slice()) { println!("{class}"); client_class = Some(class); } } let client_class = match client_class { Some(class) => class, None => return, }; let file = match client_class.architecture { dhcp::SystemArchitecture::IntelX86pc => BOOT_FILE_X64_BIOS.to_string(), dhcp::SystemArchitecture::EfiARM64 => BOOT_FILE_A64_EFI.to_string(), dhcp::SystemArchitecture::EfiX86_64 | dhcp::SystemArchitecture::EfiBC => { BOOT_FILE_X64_EFI.to_string() } _ => { eprintln!("unsupported architecture {:?}", client_class.architecture); return; } }; println!("Responding with DHCPACK"); let mut response_buf = Vec::default(); let response = DhcpPacket::new_boot_ack( packet.xid, packet.chaddr, client_uuid, context.local_address, context.local_hostname.clone(), file, ); response.write(&mut response_buf).unwrap(); socket.send_to(&response_buf, sender_addr).unwrap(); } #[allow(unused)] #[derive(Debug, Clone)] struct InterfaceAddr { interface: String, address: Ipv4Addr, network: Ipv4Net, } fn list_network_interfaces() -> Result> { unsafe { let mut ifap: *mut libc::ifaddrs = std::ptr::null_mut(); if libc::getifaddrs(&mut ifap) != 0 { return Err(std::io::Error::last_os_error()); } let mut interfaces = Vec::new(); let mut current = ifap; while !current.is_null() { let ifa = &*current; if !ifa.ifa_addr.is_null() { let addr_family = (*ifa.ifa_addr).sa_family; let name = std::ffi::CStr::from_ptr(ifa.ifa_name) .to_string_lossy() .into_owned(); if addr_family as i32 == libc::AF_INET { let addr = ifa.ifa_addr as *const libc::sockaddr_in; let mask = ifa.ifa_netmask as *const libc::sockaddr_in; let addr = Ipv4Addr::from((*addr).sin_addr.s_addr.to_ne_bytes()); let mask = Ipv4Addr::from((*mask).sin_addr.s_addr.to_ne_bytes()); let network = Ipv4Net::with_netmask(addr, mask).unwrap().trunc(); interfaces.push(InterfaceAddr { interface: name, address: addr, network, }); } } current = ifa.ifa_next; } libc::freeifaddrs(ifap); Ok(interfaces) } }