From c0148cb62800789e94ef41e34bee53e58fac02f2 Mon Sep 17 00:00:00 2001 From: diogo464 Date: Tue, 7 Oct 2025 10:34:12 +0100 Subject: split some code into modules --- src/tftp.rs | 365 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 365 insertions(+) create mode 100644 src/tftp.rs (limited to 'src/tftp.rs') diff --git a/src/tftp.rs b/src/tftp.rs new file mode 100644 index 0000000..becaa65 --- /dev/null +++ b/src/tftp.rs @@ -0,0 +1,365 @@ +use std::{ + io::{Cursor, Read as _, Result, Write}, + net::UdpSocket, + path::{Path, PathBuf}, + str::FromStr, +}; + +use crate::wire; + +pub const PORT: u16 = 69; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct InvalidTftpOp(u16); + +impl std::fmt::Display for InvalidTftpOp { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "invalid tftp opcode '{}'", self.0) + } +} + +impl std::error::Error for InvalidTftpOp {} +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum TftpOp { + ReadRequest, + WriteRequest, + Data, + Ack, + Error, + Oack, +} + +impl Into for TftpOp { + fn into(self) -> u16 { + match self { + TftpOp::ReadRequest => 1, + TftpOp::WriteRequest => 2, + TftpOp::Data => 3, + TftpOp::Ack => 4, + TftpOp::Error => 5, + TftpOp::Oack => 6, + } + } +} + +impl TryFrom for TftpOp { + type Error = InvalidTftpOp; + + fn try_from(value: u16) -> std::result::Result { + match value { + 1 => Ok(Self::ReadRequest), + 2 => Ok(Self::WriteRequest), + 3 => Ok(Self::Data), + 4 => Ok(Self::Ack), + 5 => Ok(Self::Error), + 6 => Ok(Self::Oack), + unknown => Err(InvalidTftpOp(unknown)), + } + } +} + +#[derive(Debug)] +pub struct InvalidTftpMode(String); + +impl std::fmt::Display for InvalidTftpMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "invalid tftp mode '{}'", self.0) + } +} + +impl std::error::Error for InvalidTftpMode {} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum TftpMode { + NetAscii, + Octet, + Mail, +} + +impl FromStr for TftpMode { + type Err = InvalidTftpMode; + + fn from_str(s: &str) -> std::result::Result { + match s.to_lowercase().as_str() { + "netascii" => Ok(Self::NetAscii), + "octet" => Ok(Self::Octet), + "mail" => Ok(Self::Mail), + _ => Err(InvalidTftpMode(s.to_string())), + } + } +} + +#[derive(Debug)] +pub enum TftpPacket { + Request(TftpRequestPacket), + Data(TftpDataPacket), + Ack(TftpAckPacket), + OAck(TftpOAckPacket), + Error(TftpErrorPacket), +} + +impl TftpPacket { + pub fn write(&self, writer: W) -> Result<()> { + match self { + TftpPacket::Request(p) => p.write(writer), + TftpPacket::Data(p) => p.write(writer), + TftpPacket::Ack(p) => p.write(writer), + TftpPacket::OAck(p) => p.write(writer), + TftpPacket::Error(p) => p.write(writer), + } + } +} + +#[derive(Debug)] +pub struct TftpRequestPacket { + pub filename: String, + pub mode: TftpMode, + pub tsize: Option, + pub blksize: Option, +} + +impl TftpRequestPacket { + pub fn write(&self, mut writer: W) -> Result<()> { + todo!() + } +} + +#[derive(Debug)] +pub struct TftpDataPacket { + pub block: u16, + pub data: Vec, +} + +impl TftpDataPacket { + pub fn write(&self, mut writer: W) -> Result<()> { + wire::write_u16(&mut writer, TftpOp::Data.into())?; + wire::write_u16(&mut writer, self.block)?; + wire::write(&mut writer, &self.data)?; + Ok(()) + } +} + +#[derive(Debug)] +pub struct TftpAckPacket { + pub block: u16, +} + +impl TftpAckPacket { + pub fn write(&self, mut writer: W) -> Result<()> { + wire::write_u16(&mut writer, TftpOp::Data.into())?; + wire::write_u16(&mut writer, self.block)?; + Ok(()) + } +} + +#[derive(Debug)] +pub struct TftpOAckPacket { + pub tsize: Option, + pub blksize: Option, +} + +impl TftpOAckPacket { + pub fn write(&self, mut writer: W) -> Result<()> { + wire::write_u16(&mut writer, TftpOp::Oack.into())?; + + // Only include options that were requested by the client + if let Some(blksize_val) = self.blksize { + wire::write(&mut writer, b"blksize")?; + wire::write_u8(&mut writer, 0)?; // null terminator + let blksize_str = blksize_val.to_string(); + wire::write(&mut writer, blksize_str.as_bytes())?; + wire::write_u8(&mut writer, 0)?; // null terminator + } + + if let Some(tsize_val) = self.tsize { + wire::write(&mut writer, b"tsize")?; + wire::write_u8(&mut writer, 0)?; // null terminator + let tsize_str = tsize_val.to_string(); + wire::write(&mut writer, tsize_str.as_bytes())?; + wire::write_u8(&mut writer, 0)?; // null terminator + } + + Ok(()) + } +} + +#[derive(Debug)] +pub struct TftpErrorPacket { + pub code: u16, + pub message: String, +} + +impl TftpErrorPacket { + pub fn write(&self, mut writer: W) -> Result<()> { + wire::write_u16(&mut writer, TftpOp::Error.into())?; + wire::write_u16(&mut writer, self.code)?; + wire::write_null_terminated_string(&mut writer, &self.message)?; + Ok(()) + } +} + +pub fn parse_packet(buf: &[u8]) -> Result { + let mut cursor = Cursor::new(buf); + let op = TftpOp::try_from(wire::read_u16(&mut cursor)?).unwrap(); + + match op { + TftpOp::ReadRequest => { + let filename = wire::read_null_terminated_string(&mut cursor)?; + let mode = wire::read_null_terminated_string(&mut cursor)? + .parse::() + .unwrap(); + + let mut tsize = None; + let mut blksize = None; + + while let Ok(opt_name) = wire::read_null_terminated_string(&mut cursor) { + if opt_name.is_empty() { + break; + } + let opt_data = wire::read_null_terminated_string(&mut cursor)?; + match opt_name.as_str() { + "tsize" => tsize = Some(opt_data.parse::().unwrap()), + "blksize" => blksize = Some(opt_data.parse::().unwrap()), + _ => eprintln!("unknown tftp request option '{opt_name}'"), + } + } + + Ok(TftpPacket::Request(TftpRequestPacket { + filename, + mode, + tsize, + blksize, + })) + } + TftpOp::WriteRequest => unimplemented!(), + TftpOp::Data => { + let block = wire::read_u16(&mut cursor)?; + let mut data = Vec::new(); + cursor.read_to_end(&mut data)?; + Ok(TftpPacket::Data(TftpDataPacket { block, data })) + } + TftpOp::Ack => { + let block = wire::read_u16(&mut cursor)?; + Ok(TftpPacket::Ack(TftpAckPacket { block })) + } + TftpOp::Error => { + let code = wire::read_u16(&mut cursor)?; + let message = wire::read_null_terminated_string(&mut cursor)?; + Ok(TftpPacket::Error(TftpErrorPacket { code, message })) + } + TftpOp::Oack => { + let mut tsize = None; + let mut blksize = None; + + while let Ok(opt_name) = wire::read_null_terminated_string(&mut cursor) { + if opt_name.is_empty() { + break; + } + let opt_data = wire::read_null_terminated_string(&mut cursor)?; + match opt_name.as_str() { + "tsize" => tsize = Some(opt_data.parse::().unwrap()), + "blksize" => blksize = Some(opt_data.parse::().unwrap()), + _ => eprintln!("unknown tftp ack option '{opt_name}'"), + } + } + Ok(TftpPacket::OAck(TftpOAckPacket { tsize, blksize })) + } + } +} + +pub fn serve(dir: &Path) -> Result<()> { + let socket = UdpSocket::bind(format!("0.0.0.0:{PORT}"))?; + + // TODO: this needs to be done per addr + let mut last_blksize = 512u64; + let mut current_file = PathBuf::default(); + + loop { + let mut buf = [0u8; 1500]; + let (n, addr) = socket.recv_from(&mut buf)?; + let packet = parse_packet(&buf[..n]).unwrap(); + + let response = match packet { + TftpPacket::Request(req) => { + println!( + "Request options: tsize={:?}, blksize={:?}", + req.tsize, req.blksize + ); + + let filepath = dir.join(req.filename); + current_file = filepath.clone(); + let meta = std::fs::metadata(&filepath).unwrap(); + let actual_file_size = meta.len(); + + // Only send OACK if client sent options + if req.tsize.is_some() || req.blksize.is_some() { + if let Some(blksize) = req.blksize { + last_blksize = blksize; + } + + let tsize_response = if req.tsize.is_some() { + Some(actual_file_size) + } else { + None + }; + + Some(TftpPacket::OAck(TftpOAckPacket { + tsize: req.tsize, + blksize: req.blksize, + })) + } else { + // No options, send first data block directly + let contents = std::fs::read(&filepath).unwrap(); + let block_size = 512; + let first_block = if contents.len() > block_size { + contents[..block_size].to_vec() + } else { + contents + }; + + Some(TftpPacket::Data(TftpDataPacket { + block: 1, + data: first_block, + })) + } + } + TftpPacket::Data(dat) => unimplemented!(), + TftpPacket::Ack(ack) => { + println!("Received ACK packet: block {}", ack.block); + + let contents = std::fs::read(¤t_file).unwrap(); + let next_block = ack.block + 1; + let start_offset = (next_block - 1) as u64 * last_blksize; + let end_offset = next_block as u64 * last_blksize; + let prev_start_offset = (next_block.saturating_sub(2)) as u64 * last_blksize; + let prev_remain = contents.len() - prev_start_offset as usize; + if prev_remain as u64 >= last_blksize || ack.block == 0 { + let end = std::cmp::min(end_offset as usize, contents.len()); + let block_data = contents[start_offset as usize..end].to_vec(); + println!("sending tftp data packet with {} bytes", block_data.len()); + Some(TftpPacket::Data(TftpDataPacket { + block: next_block, + data: block_data, + })) + } else { + None + } + } + TftpPacket::OAck(ack) => todo!(), + TftpPacket::Error(err) => { + println!( + "Received ERROR packet: code {}, message: {}", + err.code, err.message + ); + None + } + }; + + if let Some(response) = response { + let mut writer = Cursor::new(&mut buf[..]); + response.write(&mut writer).unwrap(); + let (response, _) = writer.split(); + socket.send_to(&response, addr).unwrap(); + } + } +} -- cgit