use std::{ collections::HashMap, io::{Cursor, Read as _, Result, Write}, net::SocketAddr, str::FromStr, }; use crate::wire; pub const PORT: u16 = 69; const DEFAULT_BLOCK_SIZE: u64 = 512; const MAX_BLOCK_SIZE: usize = 2048; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct InvalidTftpOp(u16); impl std::fmt::Display for InvalidTftpOp { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "invalid tftp opcode '{}'", self.0) } } impl std::error::Error for InvalidTftpOp {} #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum TftpOp { ReadRequest, WriteRequest, Data, Ack, Error, Oack, } impl TftpOp { pub const CODE_READ_REQUEST: u16 = 1; pub const CODE_WRITE_REQUEST: u16 = 2; pub const CODE_DATA: u16 = 3; pub const CODE_ACK: u16 = 4; pub const CODE_ERROR: u16 = 5; pub const CODE_OACK: u16 = 6; } impl From for u16 { fn from(value: TftpOp) -> Self { match value { TftpOp::ReadRequest => TftpOp::CODE_READ_REQUEST, TftpOp::WriteRequest => TftpOp::CODE_WRITE_REQUEST, TftpOp::Data => TftpOp::CODE_DATA, TftpOp::Ack => TftpOp::CODE_ACK, TftpOp::Error => TftpOp::CODE_ERROR, TftpOp::Oack => TftpOp::CODE_OACK, } } } impl TryFrom for TftpOp { type Error = InvalidTftpOp; fn try_from(value: u16) -> std::result::Result { match value { Self::CODE_READ_REQUEST => Ok(Self::ReadRequest), Self::CODE_WRITE_REQUEST => Ok(Self::WriteRequest), Self::CODE_DATA => Ok(Self::Data), Self::CODE_ACK => Ok(Self::Ack), Self::CODE_ERROR => Ok(Self::Error), Self::CODE_OACK => Ok(Self::Oack), unknown => Err(InvalidTftpOp(unknown)), } } } #[derive(Debug)] pub struct InvalidTftpMode(String); impl std::fmt::Display for InvalidTftpMode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "invalid tftp mode '{}'", self.0) } } impl std::error::Error for InvalidTftpMode {} #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum TftpMode { NetAscii, Octet, Mail, } impl FromStr for TftpMode { type Err = InvalidTftpMode; fn from_str(s: &str) -> std::result::Result { match s.to_lowercase().as_str() { "netascii" => Ok(Self::NetAscii), "octet" => Ok(Self::Octet), "mail" => Ok(Self::Mail), _ => Err(InvalidTftpMode(s.to_string())), } } } #[derive(Debug)] pub enum TftpPacket { ReadRequest(TftpReadRequestPacket), Data(TftpDataPacket), Ack(TftpAckPacket), OAck(TftpOAckPacket), Error(TftpErrorPacket), } impl TftpPacket { pub fn parse(buf: &[u8]) -> Result { parse_packet(buf) } pub fn write(&self, writer: W) -> Result<()> { match self { TftpPacket::ReadRequest(p) => p.write(writer), TftpPacket::Data(p) => p.write(writer), TftpPacket::Ack(p) => p.write(writer), TftpPacket::OAck(p) => p.write(writer), TftpPacket::Error(p) => p.write(writer), } } } #[derive(Debug)] pub struct TftpReadRequestPacket { pub filename: String, pub mode: TftpMode, pub tsize: Option, pub blksize: Option, } impl TftpReadRequestPacket { pub fn write(&self, _writer: W) -> Result<()> { todo!() } } #[derive(Debug)] pub struct TftpDataPacket { pub block: u16, pub data: Vec, } impl TftpDataPacket { pub fn write(&self, mut writer: W) -> Result<()> { wire::write_u16(&mut writer, TftpOp::Data.into())?; wire::write_u16(&mut writer, self.block)?; wire::write(&mut writer, &self.data)?; Ok(()) } } #[derive(Debug)] pub struct TftpAckPacket { pub block: u16, } impl TftpAckPacket { pub fn write(&self, mut writer: W) -> Result<()> { wire::write_u16(&mut writer, TftpOp::Data.into())?; wire::write_u16(&mut writer, self.block)?; Ok(()) } } #[derive(Debug)] pub struct TftpOAckPacket { pub tsize: Option, pub blksize: Option, } impl TftpOAckPacket { pub fn write(&self, mut writer: W) -> Result<()> { wire::write_u16(&mut writer, TftpOp::Oack.into())?; // Only include options that were requested by the client if let Some(blksize_val) = self.blksize { wire::write(&mut writer, b"blksize")?; wire::write_u8(&mut writer, 0)?; // null terminator let blksize_str = blksize_val.to_string(); wire::write(&mut writer, blksize_str.as_bytes())?; wire::write_u8(&mut writer, 0)?; // null terminator } if let Some(tsize_val) = self.tsize { wire::write(&mut writer, b"tsize")?; wire::write_u8(&mut writer, 0)?; // null terminator let tsize_str = tsize_val.to_string(); wire::write(&mut writer, tsize_str.as_bytes())?; wire::write_u8(&mut writer, 0)?; // null terminator } Ok(()) } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum TftpErrorCode { Undefined, FileNotFound, AccessViolation, DiskFull, IllegalOperation, UnknownTransferId, FileAreadyExists, NoSuchUser, Unknown(u16), } impl std::fmt::Display for TftpErrorCode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{:?} ({})", *self, u16::from(*self)) } } impl TftpErrorCode { pub const CODE_UNDEFINED: u16 = 0; pub const CODE_FILE_NOT_FOUND: u16 = 1; pub const CODE_ACCESS_VIOLATION: u16 = 2; pub const CODE_DISK_FULL: u16 = 3; pub const CODE_ILLEGAL_OPERATION: u16 = 4; pub const CODE_UNKNOWN_TRANSFER_ID: u16 = 5; pub const CODE_FILE_ALREADY_EXISTS: u16 = 6; pub const CODE_NO_SUCH_USER: u16 = 7; } impl From for TftpErrorCode { fn from(value: u16) -> Self { match value { Self::CODE_UNDEFINED => Self::Undefined, Self::CODE_FILE_NOT_FOUND => Self::FileNotFound, Self::CODE_ACCESS_VIOLATION => Self::AccessViolation, Self::CODE_DISK_FULL => Self::DiskFull, Self::CODE_ILLEGAL_OPERATION => Self::IllegalOperation, Self::CODE_UNKNOWN_TRANSFER_ID => Self::UnknownTransferId, Self::CODE_FILE_ALREADY_EXISTS => Self::FileAreadyExists, Self::CODE_NO_SUCH_USER => Self::NoSuchUser, unknown => Self::Unknown(unknown), } } } impl From for u16 { fn from(value: TftpErrorCode) -> Self { match value { TftpErrorCode::Undefined => TftpErrorCode::CODE_UNDEFINED, TftpErrorCode::FileNotFound => TftpErrorCode::CODE_FILE_NOT_FOUND, TftpErrorCode::AccessViolation => TftpErrorCode::CODE_ACCESS_VIOLATION, TftpErrorCode::DiskFull => TftpErrorCode::CODE_DISK_FULL, TftpErrorCode::IllegalOperation => TftpErrorCode::CODE_ILLEGAL_OPERATION, TftpErrorCode::UnknownTransferId => TftpErrorCode::CODE_UNKNOWN_TRANSFER_ID, TftpErrorCode::FileAreadyExists => TftpErrorCode::CODE_FILE_ALREADY_EXISTS, TftpErrorCode::NoSuchUser => TftpErrorCode::CODE_NO_SUCH_USER, TftpErrorCode::Unknown(code) => code, } } } #[derive(Debug)] pub struct TftpErrorPacket { pub code: TftpErrorCode, pub message: String, } impl TftpErrorPacket { pub fn new(code: TftpErrorCode, message: impl Into) -> Self { Self { code, message: message.into(), } } pub fn write(&self, mut writer: W) -> Result<()> { wire::write_u16(&mut writer, TftpOp::Error.into())?; wire::write_u16(&mut writer, u16::from(self.code))?; wire::write_null_terminated_string(&mut writer, &self.message)?; Ok(()) } } pub fn parse_packet(buf: &[u8]) -> Result { let mut cursor = Cursor::new(buf); let op = TftpOp::try_from(wire::read_u16(&mut cursor)?).unwrap(); match op { TftpOp::ReadRequest => { let filename = wire::read_null_terminated_string(&mut cursor)?; let mode = wire::read_null_terminated_string(&mut cursor)? .parse::() .unwrap(); let mut tsize = None; let mut blksize = None; while let Ok(opt_name) = wire::read_null_terminated_string(&mut cursor) { if opt_name.is_empty() { break; } let opt_data = wire::read_null_terminated_string(&mut cursor)?; match opt_name.as_str() { "tsize" => tsize = Some(opt_data.parse::().unwrap()), "blksize" => blksize = Some(opt_data.parse::().unwrap()), _ => eprintln!("unknown tftp request option '{opt_name}'"), } } Ok(TftpPacket::ReadRequest(TftpReadRequestPacket { filename, mode, tsize, blksize, })) } TftpOp::WriteRequest => unimplemented!(), TftpOp::Data => { let block = wire::read_u16(&mut cursor)?; let mut data = Vec::new(); cursor.read_to_end(&mut data)?; Ok(TftpPacket::Data(TftpDataPacket { block, data })) } TftpOp::Ack => { let block = wire::read_u16(&mut cursor)?; Ok(TftpPacket::Ack(TftpAckPacket { block })) } TftpOp::Error => { let code = TftpErrorCode::from(wire::read_u16(&mut cursor)?); let message = wire::read_null_terminated_string(&mut cursor)?; Ok(TftpPacket::Error(TftpErrorPacket { code, message })) } TftpOp::Oack => { let mut tsize = None; let mut blksize = None; while let Ok(opt_name) = wire::read_null_terminated_string(&mut cursor) { if opt_name.is_empty() { break; } let opt_data = wire::read_null_terminated_string(&mut cursor)?; match opt_name.as_str() { "tsize" => tsize = Some(opt_data.parse::().unwrap()), "blksize" => blksize = Some(opt_data.parse::().unwrap()), _ => eprintln!("unknown tftp ack option '{opt_name}'"), } } Ok(TftpPacket::OAck(TftpOAckPacket { tsize, blksize })) } } } pub trait FileSystem { fn stat(&self, filename: &str) -> Result; fn read(&self, filename: &str, offset: u64, buf: &mut [u8]) -> Result; } #[derive(Debug)] pub struct StaticFileSystem { files: &'static [(&'static str, &'static [u8])], } impl StaticFileSystem { pub fn new(files: &'static [(&'static str, &'static [u8])]) -> Self { Self { files } } fn find_file(&self, filename: &str) -> Result<&'static [u8]> { self.files .iter() .find(|(name, _)| *name == filename) .map(|(_, contents)| *contents) .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "file not found")) } } impl FileSystem for StaticFileSystem { fn stat(&self, filename: &str) -> Result { let file = self.find_file(filename)?; Ok(u64::try_from(file.len()).unwrap()) } fn read(&self, filename: &str, offset: u64, buf: &mut [u8]) -> Result { let file = self.find_file(filename)?; let offset = usize::try_from(offset).unwrap(); if offset >= file.len() { return Ok(0); } let rem = &file[offset..]; let copy_n = rem.len().min(buf.len()); buf[..copy_n].copy_from_slice(&rem[..copy_n]); Ok(u64::try_from(copy_n).unwrap()) } } #[derive(Debug)] struct Client { blksize: u64, filename: Option, } impl Default for Client { fn default() -> Self { Self { blksize: DEFAULT_BLOCK_SIZE, filename: None, } } } pub enum ServerCommand { Send(TftpPacket), Ignore, } impl ServerCommand { fn error(code: TftpErrorCode, message: impl Into) -> Self { Self::Send(TftpPacket::Error(TftpErrorPacket::new(code, message))) } } #[derive(Debug, Default)] pub struct Server { clients: HashMap, } impl Server { pub fn process( &mut self, fs: &dyn FileSystem, source: SocketAddr, buf: &[u8], ) -> ServerCommand { let packet = match TftpPacket::parse(buf) { Ok(packet) => packet, Err(err) => { return ServerCommand::Send(TftpPacket::Error(TftpErrorPacket::new( TftpErrorCode::Undefined, format!("invalid packet: {err}"), ))); } }; match packet { TftpPacket::ReadRequest(req) => self.process_read_req(fs, source, &req), TftpPacket::Ack(ack) => self.process_ack(fs, source, &ack), TftpPacket::Error(err) => { println!( "received error from client {}: ({}) {}", source, err.code, err.message ); self.clients.remove(&source); ServerCommand::Ignore } TftpPacket::Data(_) | TftpPacket::OAck(_) => ServerCommand::Ignore, } } fn process_read_req( &mut self, fs: &dyn FileSystem, source: SocketAddr, req: &TftpReadRequestPacket, ) -> ServerCommand { println!( "Request options: tsize={:?}, blksize={:?}", req.tsize, req.blksize ); let client = self.clients.entry(source).or_default(); client.filename = Some(req.filename.clone()); // Only send OACK if client sent options if req.tsize.is_some() || req.blksize.is_some() { if let Some(blksize) = req.blksize { client.blksize = blksize; } let tsize_response = if req.tsize.is_some() { let filesize = match fs.stat(&req.filename) { Ok(filesize) => filesize, Err(err) => { return ServerCommand::error( TftpErrorCode::Undefined, format!("failed to obtain file size: {}", err), ); } }; Some(filesize) } else { None }; ServerCommand::Send(TftpPacket::OAck(TftpOAckPacket { tsize: tsize_response, blksize: req.blksize, })) } else { // No options, send first data block directly let options = self.clients.entry(source).or_default(); let block_size = usize::try_from(options.blksize).unwrap(); assert!(block_size <= MAX_BLOCK_SIZE); let mut contents = [0u8; MAX_BLOCK_SIZE]; let contents = &mut contents[..block_size]; let n = match fs.read(&req.filename, 0, contents) { Ok(n) => usize::try_from(n).unwrap(), Err(err) => { return ServerCommand::error( TftpErrorCode::Undefined, format!("failed to read file contents: {}", err), ); } }; ServerCommand::Send(TftpPacket::Data(TftpDataPacket { block: 1, data: contents[..n].to_vec(), })) } } fn process_ack( &mut self, fs: &dyn FileSystem, source: SocketAddr, ack: &TftpAckPacket, ) -> ServerCommand { println!("Received ACK packet: block {}", ack.block); let client = self.clients.entry(source).or_default(); let filename = match &client.filename { Some(filename) => filename, None => { return ServerCommand::error( TftpErrorCode::Undefined, "unknown filename for client", ); } }; let block_size = usize::try_from(client.blksize).unwrap(); let next_block = ack.block + 1; let start_offset = (usize::from(next_block) - 1) * block_size; let mut contents = [0u8; MAX_BLOCK_SIZE]; let contents = &mut contents[..block_size]; let n = match fs.read(filename, u64::try_from(start_offset).unwrap(), contents) { Ok(n) => usize::try_from(n).unwrap(), Err(err) => { return ServerCommand::error( TftpErrorCode::Undefined, format!("failed to read file contents: {}", err), ); } }; let contents = &contents[..n]; if contents.is_empty() { return ServerCommand::Ignore; } ServerCommand::Send(TftpPacket::Data(TftpDataPacket { block: next_block, data: contents.to_vec(), })) } }