From c0148cb62800789e94ef41e34bee53e58fac02f2 Mon Sep 17 00:00:00 2001 From: diogo464 Date: Tue, 7 Oct 2025 10:34:12 +0100 Subject: split some code into modules --- src/dhcp.rs | 214 +++++++++++++++++++++++++++++++++++ src/main.rs | 7 +- src/tftp.rs | 365 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/wire.rs | 81 ++++++++++++++ 4 files changed, 666 insertions(+), 1 deletion(-) create mode 100644 src/dhcp.rs create mode 100644 src/tftp.rs create mode 100644 src/wire.rs (limited to 'src') diff --git a/src/dhcp.rs b/src/dhcp.rs new file mode 100644 index 0000000..b53ee92 --- /dev/null +++ b/src/dhcp.rs @@ -0,0 +1,214 @@ +use std::{ + io::{Result, Write}, + net::Ipv4Addr, +}; + +use crate::wire; + +const MAGIC_COOKIE: [u8; 4] = [0x63, 0x82, 0x53, 0x63]; + +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)] +pub enum BootOp { + #[default] + Request, + Reply, +} + +impl BootOp { + pub const OP_REQUEST: u8 = 1; + pub const OP_REPLY: u8 = 2; +} + +impl From for u8 { + fn from(value: BootOp) -> Self { + match value { + BootOp::Request => BootOp::OP_REQUEST, + BootOp::Reply => BootOp::OP_REPLY, + } + } +} + +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)] +pub enum HardwareType { + #[default] + Ethernet, +} + +impl HardwareType { + pub const TYPE_ETHER: u8 = 1; + pub const LEN_ETHER: u8 = 6; + + pub fn hardware_len(&self) -> u8 { + match self { + HardwareType::Ethernet => Self::LEN_ETHER, + } + } +} + +impl From for u8 { + fn from(value: HardwareType) -> Self { + match value { + HardwareType::Ethernet => HardwareType::TYPE_ETHER, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum DhcpMessageType { + Ack, +} + +impl DhcpMessageType { + pub const CODE_ACK: u8 = 5; + + pub fn code(&self) -> u8 { + match self { + DhcpMessageType::Ack => Self::CODE_ACK, + } + } +} + +#[derive(Debug, Clone)] +pub enum DhcpOption { + Pad, + End, + MessageType(DhcpMessageType), + ServerIdentifier(Ipv4Addr), + VendorClassIdentifier(String), + TftpServerName(String), + TftpFileName(String), + UserClassInformation(String), + ClientMachineIdentifier(String), + Unknown { code: u8, data: Vec }, +} + +impl DhcpOption { + pub const CODE_PAD: u8 = 0; + pub const CODE_END: u8 = 255; + pub const CODE_DHCP_MESSAGE_TYPE: u8 = 53; + pub const CODE_DHCP_SERVER_IDENTIFIER: u8 = 54; + pub const CODE_VENDOR_CLASS_IDENTIFIER: u8 = 60; + pub const CODE_TFTP_SERVER_NAME: u8 = 66; + pub const CODE_TFTP_FILE_NAME: u8 = 67; + pub const CODE_USER_CLASS_INFORMATION: u8 = 77; + pub const CODE_CLIENT_MACHINE_IDENTIFIER: u8 = 97; + + pub fn code(&self) -> u8 { + match self { + DhcpOption::Pad => Self::CODE_PAD, + DhcpOption::End => Self::CODE_END, + DhcpOption::MessageType(_) => Self::CODE_DHCP_MESSAGE_TYPE, + DhcpOption::ServerIdentifier(_) => Self::CODE_DHCP_SERVER_IDENTIFIER, + DhcpOption::VendorClassIdentifier(_) => Self::CODE_VENDOR_CLASS_IDENTIFIER, + DhcpOption::TftpServerName(_) => Self::CODE_TFTP_SERVER_NAME, + DhcpOption::TftpFileName(_) => Self::CODE_TFTP_FILE_NAME, + DhcpOption::UserClassInformation(_) => Self::CODE_USER_CLASS_INFORMATION, + DhcpOption::ClientMachineIdentifier(_) => Self::CODE_CLIENT_MACHINE_IDENTIFIER, + DhcpOption::Unknown { code, .. } => *code, + } + } +} + +#[derive(Debug)] +pub struct DhcpPacket { + pub op: BootOp, + pub htype: HardwareType, + pub xid: u32, + pub secs: u16, + pub flags: u16, + pub ciaddr: Ipv4Addr, + pub yiaddr: Ipv4Addr, + pub siaddr: Ipv4Addr, + pub giaddr: Ipv4Addr, + pub chaddr: [u8; 16], + // server host name + pub sname: Option, + // boot file name + pub file: Option, + pub options: Vec, +} + +impl Default for DhcpPacket { + fn default() -> Self { + Self { + op: Default::default(), + htype: Default::default(), + xid: Default::default(), + secs: Default::default(), + flags: Default::default(), + ciaddr: Ipv4Addr::UNSPECIFIED, + yiaddr: Ipv4Addr::UNSPECIFIED, + siaddr: Ipv4Addr::UNSPECIFIED, + giaddr: Ipv4Addr::UNSPECIFIED, + chaddr: Default::default(), + sname: Default::default(), + file: Default::default(), + options: Default::default(), + } + } +} + +pub fn write_packet(mut writer: W, packet: &DhcpPacket) -> Result<()> { + wire::write_u8(&mut writer, u8::from(packet.op))?; + wire::write_u8(&mut writer, u8::from(packet.htype))?; + wire::write_u8(&mut writer, packet.htype.hardware_len())?; + wire::write_u8(&mut writer, 0)?; // hops + wire::write_u32(&mut writer, packet.xid)?; + wire::write_u16(&mut writer, packet.secs)?; + wire::write_u16(&mut writer, packet.flags)?; + wire::write_ipv4(&mut writer, packet.ciaddr)?; + wire::write_ipv4(&mut writer, packet.yiaddr)?; + wire::write_ipv4(&mut writer, packet.siaddr)?; + wire::write_ipv4(&mut writer, packet.giaddr)?; + wire::write(&mut writer, &packet.chaddr)?; + match &packet.sname { + Some(name) => wire::write_null_terminated_string(&mut writer, &name)?, + None => wire::write_null_terminated_string(&mut writer, "")?, + }; + match &packet.file { + Some(name) => wire::write_null_terminated_string(&mut writer, &name)?, + None => wire::write_null_terminated_string(&mut writer, "")?, + }; + wire::write(&mut writer, &MAGIC_COOKIE)?; + for option in &packet.options { + write_option(&mut writer, option)?; + } + write_option(&mut writer, &DhcpOption::End)?; + Ok(()) +} + +pub fn write_option(mut writer: W, option: &DhcpOption) -> Result<()> { + wire::write_u8(&mut writer, option.code())?; + match option { + DhcpOption::Pad | DhcpOption::End => {} + DhcpOption::MessageType(t) => { + wire::write_u8(&mut writer, 1)?; + wire::write_u8(&mut writer, t.code())?; + } + DhcpOption::ServerIdentifier(ip) => { + wire::write_u8(&mut writer, 4)?; + wire::write_ipv4(&mut writer, *ip)?; + } + DhcpOption::VendorClassIdentifier(vendor_class) => { + write_option_len_prefixed_string(&mut writer, &vendor_class)? + } + DhcpOption::TftpServerName(name) => write_option_len_prefixed_string(&mut writer, &name)?, + DhcpOption::TftpFileName(name) => write_option_len_prefixed_string(&mut writer, &name)?, + DhcpOption::UserClassInformation(user_class) => { + write_option_len_prefixed_string(&mut writer, &user_class)? + } + DhcpOption::ClientMachineIdentifier(identifier) => { + write_option_len_prefixed_string(&mut writer, &identifier)? + } + DhcpOption::Unknown { data, .. } => { + wire::write_u8(&mut writer, u8::try_from(data.len()).unwrap())?; + wire::write(&mut writer, &data)?; + } + } + Ok(()) +} + +fn write_option_len_prefixed_string(mut writer: W, s: &str) -> Result<()> { + wire::write_u8(&mut writer, u8::try_from(s.len()).unwrap())?; + wire::write(&mut writer, s.as_bytes()) +} diff --git a/src/main.rs b/src/main.rs index 87ea283..3d86a8e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,8 @@ +#![feature(cursor_split)] +pub mod dhcp; +pub mod tftp; +pub mod wire; + use std::io::{BufRead, Cursor, Read, Result, Write}; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket}; @@ -15,7 +20,7 @@ const MAGIC_COOKIE: [u8; 4] = [0x63, 0x82, 0x53, 0x63]; const BOOT_FILE_NAME: &[u8] = b"ipxe.efi"; const BOOT_FILE_NAME_IPXE: &[u8] = b"test.ipxe"; -const LOCAL_IPV4: Ipv4Addr = Ipv4Addr::new(192, 168, 2, 184); +const LOCAL_IPV4: Ipv4Addr = Ipv4Addr::new(192, 168, 1, 100); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] enum BootOp { diff --git a/src/tftp.rs b/src/tftp.rs new file mode 100644 index 0000000..becaa65 --- /dev/null +++ b/src/tftp.rs @@ -0,0 +1,365 @@ +use std::{ + io::{Cursor, Read as _, Result, Write}, + net::UdpSocket, + path::{Path, PathBuf}, + str::FromStr, +}; + +use crate::wire; + +pub const PORT: u16 = 69; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct InvalidTftpOp(u16); + +impl std::fmt::Display for InvalidTftpOp { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "invalid tftp opcode '{}'", self.0) + } +} + +impl std::error::Error for InvalidTftpOp {} +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum TftpOp { + ReadRequest, + WriteRequest, + Data, + Ack, + Error, + Oack, +} + +impl Into for TftpOp { + fn into(self) -> u16 { + match self { + TftpOp::ReadRequest => 1, + TftpOp::WriteRequest => 2, + TftpOp::Data => 3, + TftpOp::Ack => 4, + TftpOp::Error => 5, + TftpOp::Oack => 6, + } + } +} + +impl TryFrom for TftpOp { + type Error = InvalidTftpOp; + + fn try_from(value: u16) -> std::result::Result { + match value { + 1 => Ok(Self::ReadRequest), + 2 => Ok(Self::WriteRequest), + 3 => Ok(Self::Data), + 4 => Ok(Self::Ack), + 5 => Ok(Self::Error), + 6 => Ok(Self::Oack), + unknown => Err(InvalidTftpOp(unknown)), + } + } +} + +#[derive(Debug)] +pub struct InvalidTftpMode(String); + +impl std::fmt::Display for InvalidTftpMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "invalid tftp mode '{}'", self.0) + } +} + +impl std::error::Error for InvalidTftpMode {} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum TftpMode { + NetAscii, + Octet, + Mail, +} + +impl FromStr for TftpMode { + type Err = InvalidTftpMode; + + fn from_str(s: &str) -> std::result::Result { + match s.to_lowercase().as_str() { + "netascii" => Ok(Self::NetAscii), + "octet" => Ok(Self::Octet), + "mail" => Ok(Self::Mail), + _ => Err(InvalidTftpMode(s.to_string())), + } + } +} + +#[derive(Debug)] +pub enum TftpPacket { + Request(TftpRequestPacket), + Data(TftpDataPacket), + Ack(TftpAckPacket), + OAck(TftpOAckPacket), + Error(TftpErrorPacket), +} + +impl TftpPacket { + pub fn write(&self, writer: W) -> Result<()> { + match self { + TftpPacket::Request(p) => p.write(writer), + TftpPacket::Data(p) => p.write(writer), + TftpPacket::Ack(p) => p.write(writer), + TftpPacket::OAck(p) => p.write(writer), + TftpPacket::Error(p) => p.write(writer), + } + } +} + +#[derive(Debug)] +pub struct TftpRequestPacket { + pub filename: String, + pub mode: TftpMode, + pub tsize: Option, + pub blksize: Option, +} + +impl TftpRequestPacket { + pub fn write(&self, mut writer: W) -> Result<()> { + todo!() + } +} + +#[derive(Debug)] +pub struct TftpDataPacket { + pub block: u16, + pub data: Vec, +} + +impl TftpDataPacket { + pub fn write(&self, mut writer: W) -> Result<()> { + wire::write_u16(&mut writer, TftpOp::Data.into())?; + wire::write_u16(&mut writer, self.block)?; + wire::write(&mut writer, &self.data)?; + Ok(()) + } +} + +#[derive(Debug)] +pub struct TftpAckPacket { + pub block: u16, +} + +impl TftpAckPacket { + pub fn write(&self, mut writer: W) -> Result<()> { + wire::write_u16(&mut writer, TftpOp::Data.into())?; + wire::write_u16(&mut writer, self.block)?; + Ok(()) + } +} + +#[derive(Debug)] +pub struct TftpOAckPacket { + pub tsize: Option, + pub blksize: Option, +} + +impl TftpOAckPacket { + pub fn write(&self, mut writer: W) -> Result<()> { + wire::write_u16(&mut writer, TftpOp::Oack.into())?; + + // Only include options that were requested by the client + if let Some(blksize_val) = self.blksize { + wire::write(&mut writer, b"blksize")?; + wire::write_u8(&mut writer, 0)?; // null terminator + let blksize_str = blksize_val.to_string(); + wire::write(&mut writer, blksize_str.as_bytes())?; + wire::write_u8(&mut writer, 0)?; // null terminator + } + + if let Some(tsize_val) = self.tsize { + wire::write(&mut writer, b"tsize")?; + wire::write_u8(&mut writer, 0)?; // null terminator + let tsize_str = tsize_val.to_string(); + wire::write(&mut writer, tsize_str.as_bytes())?; + wire::write_u8(&mut writer, 0)?; // null terminator + } + + Ok(()) + } +} + +#[derive(Debug)] +pub struct TftpErrorPacket { + pub code: u16, + pub message: String, +} + +impl TftpErrorPacket { + pub fn write(&self, mut writer: W) -> Result<()> { + wire::write_u16(&mut writer, TftpOp::Error.into())?; + wire::write_u16(&mut writer, self.code)?; + wire::write_null_terminated_string(&mut writer, &self.message)?; + Ok(()) + } +} + +pub fn parse_packet(buf: &[u8]) -> Result { + let mut cursor = Cursor::new(buf); + let op = TftpOp::try_from(wire::read_u16(&mut cursor)?).unwrap(); + + match op { + TftpOp::ReadRequest => { + let filename = wire::read_null_terminated_string(&mut cursor)?; + let mode = wire::read_null_terminated_string(&mut cursor)? + .parse::() + .unwrap(); + + let mut tsize = None; + let mut blksize = None; + + while let Ok(opt_name) = wire::read_null_terminated_string(&mut cursor) { + if opt_name.is_empty() { + break; + } + let opt_data = wire::read_null_terminated_string(&mut cursor)?; + match opt_name.as_str() { + "tsize" => tsize = Some(opt_data.parse::().unwrap()), + "blksize" => blksize = Some(opt_data.parse::().unwrap()), + _ => eprintln!("unknown tftp request option '{opt_name}'"), + } + } + + Ok(TftpPacket::Request(TftpRequestPacket { + filename, + mode, + tsize, + blksize, + })) + } + TftpOp::WriteRequest => unimplemented!(), + TftpOp::Data => { + let block = wire::read_u16(&mut cursor)?; + let mut data = Vec::new(); + cursor.read_to_end(&mut data)?; + Ok(TftpPacket::Data(TftpDataPacket { block, data })) + } + TftpOp::Ack => { + let block = wire::read_u16(&mut cursor)?; + Ok(TftpPacket::Ack(TftpAckPacket { block })) + } + TftpOp::Error => { + let code = wire::read_u16(&mut cursor)?; + let message = wire::read_null_terminated_string(&mut cursor)?; + Ok(TftpPacket::Error(TftpErrorPacket { code, message })) + } + TftpOp::Oack => { + let mut tsize = None; + let mut blksize = None; + + while let Ok(opt_name) = wire::read_null_terminated_string(&mut cursor) { + if opt_name.is_empty() { + break; + } + let opt_data = wire::read_null_terminated_string(&mut cursor)?; + match opt_name.as_str() { + "tsize" => tsize = Some(opt_data.parse::().unwrap()), + "blksize" => blksize = Some(opt_data.parse::().unwrap()), + _ => eprintln!("unknown tftp ack option '{opt_name}'"), + } + } + Ok(TftpPacket::OAck(TftpOAckPacket { tsize, blksize })) + } + } +} + +pub fn serve(dir: &Path) -> Result<()> { + let socket = UdpSocket::bind(format!("0.0.0.0:{PORT}"))?; + + // TODO: this needs to be done per addr + let mut last_blksize = 512u64; + let mut current_file = PathBuf::default(); + + loop { + let mut buf = [0u8; 1500]; + let (n, addr) = socket.recv_from(&mut buf)?; + let packet = parse_packet(&buf[..n]).unwrap(); + + let response = match packet { + TftpPacket::Request(req) => { + println!( + "Request options: tsize={:?}, blksize={:?}", + req.tsize, req.blksize + ); + + let filepath = dir.join(req.filename); + current_file = filepath.clone(); + let meta = std::fs::metadata(&filepath).unwrap(); + let actual_file_size = meta.len(); + + // Only send OACK if client sent options + if req.tsize.is_some() || req.blksize.is_some() { + if let Some(blksize) = req.blksize { + last_blksize = blksize; + } + + let tsize_response = if req.tsize.is_some() { + Some(actual_file_size) + } else { + None + }; + + Some(TftpPacket::OAck(TftpOAckPacket { + tsize: req.tsize, + blksize: req.blksize, + })) + } else { + // No options, send first data block directly + let contents = std::fs::read(&filepath).unwrap(); + let block_size = 512; + let first_block = if contents.len() > block_size { + contents[..block_size].to_vec() + } else { + contents + }; + + Some(TftpPacket::Data(TftpDataPacket { + block: 1, + data: first_block, + })) + } + } + TftpPacket::Data(dat) => unimplemented!(), + TftpPacket::Ack(ack) => { + println!("Received ACK packet: block {}", ack.block); + + let contents = std::fs::read(¤t_file).unwrap(); + let next_block = ack.block + 1; + let start_offset = (next_block - 1) as u64 * last_blksize; + let end_offset = next_block as u64 * last_blksize; + let prev_start_offset = (next_block.saturating_sub(2)) as u64 * last_blksize; + let prev_remain = contents.len() - prev_start_offset as usize; + if prev_remain as u64 >= last_blksize || ack.block == 0 { + let end = std::cmp::min(end_offset as usize, contents.len()); + let block_data = contents[start_offset as usize..end].to_vec(); + println!("sending tftp data packet with {} bytes", block_data.len()); + Some(TftpPacket::Data(TftpDataPacket { + block: next_block, + data: block_data, + })) + } else { + None + } + } + TftpPacket::OAck(ack) => todo!(), + TftpPacket::Error(err) => { + println!( + "Received ERROR packet: code {}, message: {}", + err.code, err.message + ); + None + } + }; + + if let Some(response) = response { + let mut writer = Cursor::new(&mut buf[..]); + response.write(&mut writer).unwrap(); + let (response, _) = writer.split(); + socket.send_to(&response, addr).unwrap(); + } + } +} diff --git a/src/wire.rs b/src/wire.rs new file mode 100644 index 0000000..dda7690 --- /dev/null +++ b/src/wire.rs @@ -0,0 +1,81 @@ +use std::{ + io::{BufRead, Read, Result, Write}, + net::Ipv4Addr, +}; + +pub fn write(mut writer: W, v: &[u8]) -> Result<()> { + writer.write_all(v) +} + +pub fn write_u8(mut writer: W, v: u8) -> Result<()> { + writer.write_all(&[v]) +} + +pub fn write_u16(mut writer: W, v: u16) -> Result<()> { + writer.write_all(&u16::to_be_bytes(v)) +} + +pub fn write_u32(mut writer: W, v: u32) -> Result<()> { + writer.write_all(&u32::to_be_bytes(v)) +} + +pub fn write_ipv4(mut writer: W, v: Ipv4Addr) -> Result<()> { + writer.write_all(&v.octets()) +} + +pub fn write_null_terminated_string(mut writer: W, v: &str) -> Result<()> { + writer.write_all(v.as_bytes())?; + writer.write_all(&[0u8]) +} + +pub fn read_u8(mut reader: R) -> Result { + let mut buf = [0u8; 1]; + reader.read_exact(&mut buf)?; + Ok(buf[0]) +} + +pub fn read_u16(mut reader: R) -> Result { + let mut buf = [0u8; 2]; + reader.read_exact(&mut buf)?; + Ok(u16::from_be_bytes(buf)) +} + +pub fn read_u32(mut reader: R) -> Result { + let mut buf = [0u8; 4]; + reader.read_exact(&mut buf)?; + Ok(u32::from_be_bytes(buf)) +} + +pub fn read_arr(mut reader: R) -> Result<[u8; N]> { + let mut buf = [0u8; N]; + reader.read_exact(&mut buf)?; + Ok(buf) +} + +pub fn read_null_terminated_vec(mut reader: R) -> Result> { + let mut buf = Vec::default(); + reader.read_until(0, &mut buf)?; + buf.pop(); + Ok(buf) +} + +pub fn read_null_terminated_string(reader: R) -> Result { + let buf = read_null_terminated_vec(reader)?; + Ok(String::from_utf8(buf).unwrap()) +} + +pub fn read_len8_prefixed_vec(mut reader: R) -> Result> { + let len = read_u8(&mut reader)?; + let mut buf = vec![0u8; len as usize]; + reader.read_exact(&mut buf)?; + Ok(buf) +} + +pub fn read_len8_prefixed_string(reader: R) -> Result { + let buf = read_len8_prefixed_vec(reader)?; + Ok(String::from_utf8(buf).unwrap()) +} + +pub fn read_ipv4(reader: R) -> Result { + Ok(Ipv4Addr::from_octets(read_arr(reader)?)) +} -- cgit