aboutsummaryrefslogtreecommitdiff
path: root/src/tftp.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/tftp.rs')
-rw-r--r--src/tftp.rs377
1 files changed, 287 insertions, 90 deletions
diff --git a/src/tftp.rs b/src/tftp.rs
index d986a44..72bac22 100644
--- a/src/tftp.rs
+++ b/src/tftp.rs
@@ -1,7 +1,7 @@
1use std::{ 1use std::{
2 collections::HashMap,
2 io::{Cursor, Read as _, Result, Write}, 3 io::{Cursor, Read as _, Result, Write},
3 net::UdpSocket, 4 net::SocketAddr,
4 path::{Path, PathBuf},
5 str::FromStr, 5 str::FromStr,
6}; 6};
7 7
@@ -9,6 +9,9 @@ use crate::wire;
9 9
10pub const PORT: u16 = 69; 10pub const PORT: u16 = 69;
11 11
12const DEFAULT_BLOCK_SIZE: u64 = 512;
13const MAX_BLOCK_SIZE: usize = 2048;
14
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 15#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
13pub struct InvalidTftpOp(u16); 16pub struct InvalidTftpOp(u16);
14 17
@@ -91,7 +94,7 @@ impl FromStr for TftpMode {
91 94
92#[derive(Debug)] 95#[derive(Debug)]
93pub enum TftpPacket { 96pub enum TftpPacket {
94 Request(TftpRequestPacket), 97 ReadRequest(TftpReadRequestPacket),
95 Data(TftpDataPacket), 98 Data(TftpDataPacket),
96 Ack(TftpAckPacket), 99 Ack(TftpAckPacket),
97 OAck(TftpOAckPacket), 100 OAck(TftpOAckPacket),
@@ -99,9 +102,13 @@ pub enum TftpPacket {
99} 102}
100 103
101impl TftpPacket { 104impl TftpPacket {
105 pub fn parse(buf: &[u8]) -> Result<Self> {
106 parse_packet(buf)
107 }
108
102 pub fn write<W: Write>(&self, writer: W) -> Result<()> { 109 pub fn write<W: Write>(&self, writer: W) -> Result<()> {
103 match self { 110 match self {
104 TftpPacket::Request(p) => p.write(writer), 111 TftpPacket::ReadRequest(p) => p.write(writer),
105 TftpPacket::Data(p) => p.write(writer), 112 TftpPacket::Data(p) => p.write(writer),
106 TftpPacket::Ack(p) => p.write(writer), 113 TftpPacket::Ack(p) => p.write(writer),
107 TftpPacket::OAck(p) => p.write(writer), 114 TftpPacket::OAck(p) => p.write(writer),
@@ -111,14 +118,14 @@ impl TftpPacket {
111} 118}
112 119
113#[derive(Debug)] 120#[derive(Debug)]
114pub struct TftpRequestPacket { 121pub struct TftpReadRequestPacket {
115 pub filename: String, 122 pub filename: String,
116 pub mode: TftpMode, 123 pub mode: TftpMode,
117 pub tsize: Option<u64>, 124 pub tsize: Option<u64>,
118 pub blksize: Option<u64>, 125 pub blksize: Option<u64>,
119} 126}
120 127
121impl TftpRequestPacket { 128impl TftpReadRequestPacket {
122 pub fn write<W: Write>(&self, mut writer: W) -> Result<()> { 129 pub fn write<W: Write>(&self, mut writer: W) -> Result<()> {
123 todo!() 130 todo!()
124 } 131 }
@@ -183,16 +190,85 @@ impl TftpOAckPacket {
183 } 190 }
184} 191}
185 192
193#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
194pub enum TftpErrorCode {
195 Undefined,
196 FileNotFound,
197 AccessViolation,
198 DiskFull,
199 IllegalOperation,
200 UnknownTransferId,
201 FileAreadyExists,
202 NoSuchUser,
203 Unknown(u16),
204}
205
206impl std::fmt::Display for TftpErrorCode {
207 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208 write!(f, "{:?} ({})", *self, u16::from(*self))
209 }
210}
211
212impl TftpErrorCode {
213 pub const CODE_UNDEFINED: u16 = 0;
214 pub const CODE_FILE_NOT_FOUND: u16 = 1;
215 pub const CODE_ACCESS_VIOLATION: u16 = 2;
216 pub const CODE_DISK_FULL: u16 = 3;
217 pub const CODE_ILLEGAL_OPERATION: u16 = 4;
218 pub const CODE_UNKNOWN_TRANSFER_ID: u16 = 5;
219 pub const CODE_FILE_ALREADY_EXISTS: u16 = 6;
220 pub const CODE_NO_SUCH_USER: u16 = 7;
221}
222
223impl From<u16> for TftpErrorCode {
224 fn from(value: u16) -> Self {
225 match value {
226 Self::CODE_UNDEFINED => Self::Undefined,
227 Self::CODE_FILE_NOT_FOUND => Self::FileNotFound,
228 Self::CODE_ACCESS_VIOLATION => Self::AccessViolation,
229 Self::CODE_DISK_FULL => Self::DiskFull,
230 Self::CODE_ILLEGAL_OPERATION => Self::IllegalOperation,
231 Self::CODE_UNKNOWN_TRANSFER_ID => Self::UnknownTransferId,
232 Self::CODE_FILE_ALREADY_EXISTS => Self::FileAreadyExists,
233 Self::CODE_NO_SUCH_USER => Self::NoSuchUser,
234 unknown => Self::Unknown(unknown),
235 }
236 }
237}
238
239impl From<TftpErrorCode> for u16 {
240 fn from(value: TftpErrorCode) -> Self {
241 match value {
242 TftpErrorCode::Undefined => TftpErrorCode::CODE_UNDEFINED,
243 TftpErrorCode::FileNotFound => TftpErrorCode::CODE_FILE_NOT_FOUND,
244 TftpErrorCode::AccessViolation => TftpErrorCode::CODE_ACCESS_VIOLATION,
245 TftpErrorCode::DiskFull => TftpErrorCode::CODE_DISK_FULL,
246 TftpErrorCode::IllegalOperation => TftpErrorCode::CODE_ILLEGAL_OPERATION,
247 TftpErrorCode::UnknownTransferId => TftpErrorCode::CODE_UNKNOWN_TRANSFER_ID,
248 TftpErrorCode::FileAreadyExists => TftpErrorCode::CODE_FILE_ALREADY_EXISTS,
249 TftpErrorCode::NoSuchUser => TftpErrorCode::CODE_NO_SUCH_USER,
250 TftpErrorCode::Unknown(code) => code,
251 }
252 }
253}
254
186#[derive(Debug)] 255#[derive(Debug)]
187pub struct TftpErrorPacket { 256pub struct TftpErrorPacket {
188 pub code: u16, 257 pub code: TftpErrorCode,
189 pub message: String, 258 pub message: String,
190} 259}
191 260
192impl TftpErrorPacket { 261impl TftpErrorPacket {
262 pub fn new(code: TftpErrorCode, message: impl Into<String>) -> Self {
263 Self {
264 code,
265 message: message.into(),
266 }
267 }
268
193 pub fn write<W: Write>(&self, mut writer: W) -> Result<()> { 269 pub fn write<W: Write>(&self, mut writer: W) -> Result<()> {
194 wire::write_u16(&mut writer, TftpOp::Error.into())?; 270 wire::write_u16(&mut writer, TftpOp::Error.into())?;
195 wire::write_u16(&mut writer, self.code)?; 271 wire::write_u16(&mut writer, u16::from(self.code))?;
196 wire::write_null_terminated_string(&mut writer, &self.message)?; 272 wire::write_null_terminated_string(&mut writer, &self.message)?;
197 Ok(()) 273 Ok(())
198 } 274 }
@@ -224,7 +300,7 @@ pub fn parse_packet(buf: &[u8]) -> Result<TftpPacket> {
224 } 300 }
225 } 301 }
226 302
227 Ok(TftpPacket::Request(TftpRequestPacket { 303 Ok(TftpPacket::ReadRequest(TftpReadRequestPacket {
228 filename, 304 filename,
229 mode, 305 mode,
230 tsize, 306 tsize,
@@ -243,7 +319,7 @@ pub fn parse_packet(buf: &[u8]) -> Result<TftpPacket> {
243 Ok(TftpPacket::Ack(TftpAckPacket { block })) 319 Ok(TftpPacket::Ack(TftpAckPacket { block }))
244 } 320 }
245 TftpOp::Error => { 321 TftpOp::Error => {
246 let code = wire::read_u16(&mut cursor)?; 322 let code = TftpErrorCode::from(wire::read_u16(&mut cursor)?);
247 let message = wire::read_null_terminated_string(&mut cursor)?; 323 let message = wire::read_null_terminated_string(&mut cursor)?;
248 Ok(TftpPacket::Error(TftpErrorPacket { code, message })) 324 Ok(TftpPacket::Error(TftpErrorPacket { code, message }))
249 } 325 }
@@ -267,101 +343,222 @@ pub fn parse_packet(buf: &[u8]) -> Result<TftpPacket> {
267 } 343 }
268} 344}
269 345
270pub fn serve(dir: impl AsRef<Path>) -> Result<()> { 346pub trait FileSystem {
271 let dir = dir.as_ref(); 347 fn stat(&self, filename: &str) -> Result<u64>;
272 let socket = UdpSocket::bind(format!("0.0.0.0:{PORT}"))?; 348 fn read(&self, filename: &str, offset: u64, buf: &mut [u8]) -> Result<u64>;
349}
273 350
274 // TODO: this needs to be done per addr 351#[derive(Debug)]
275 let mut last_blksize = 512u64; 352pub struct StaticFileSystem {
276 let mut current_file = PathBuf::default(); 353 files: &'static [(&'static str, &'static [u8])],
354}
355
356impl StaticFileSystem {
357 pub fn new(files: &'static [(&'static str, &'static [u8])]) -> Self {
358 Self { files }
359 }
277 360
278 loop { 361 fn find_file(&self, filename: &str) -> Result<&'static [u8]> {
279 let mut buf = [0u8; 1500]; 362 self.files
280 let (n, addr) = socket.recv_from(&mut buf)?; 363 .iter()
281 let packet = parse_packet(&buf[..n]).unwrap(); 364 .find(|(name, _)| *name == filename)
365 .map(|(_, contents)| *contents)
366 .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::NotFound, "file not found"))
367 }
368}
369
370impl FileSystem for StaticFileSystem {
371 fn stat(&self, filename: &str) -> Result<u64> {
372 let file = self.find_file(filename)?;
373 Ok(u64::try_from(file.len()).unwrap())
374 }
375
376 fn read(&self, filename: &str, offset: u64, buf: &mut [u8]) -> Result<u64> {
377 let file = self.find_file(filename)?;
378 let offset = usize::try_from(offset).unwrap();
379 if offset >= file.len() {
380 return Ok(0);
381 }
382
383 let rem = &file[offset..];
384 let copy_n = rem.len().min(buf.len());
385 buf[..copy_n].copy_from_slice(&rem[..copy_n]);
386 Ok(u64::try_from(copy_n).unwrap())
387 }
388}
282 389
283 let response = match packet { 390#[derive(Debug)]
284 TftpPacket::Request(req) => { 391struct Client {
392 blksize: u64,
393 filename: Option<String>,
394}
395
396impl Default for Client {
397 fn default() -> Self {
398 Self {
399 blksize: DEFAULT_BLOCK_SIZE,
400 filename: None,
401 }
402 }
403}
404
405pub enum ServerCommand {
406 Send(TftpPacket),
407 Ignore,
408}
409
410impl ServerCommand {
411 fn error(code: TftpErrorCode, message: impl Into<String>) -> Self {
412 Self::Send(TftpPacket::Error(TftpErrorPacket::new(code, message)))
413 }
414}
415
416#[derive(Debug, Default)]
417pub struct Server {
418 clients: HashMap<SocketAddr, Client>,
419}
420
421impl Server {
422 pub fn process(
423 &mut self,
424 fs: &dyn FileSystem,
425 source: SocketAddr,
426 buf: &[u8],
427 ) -> ServerCommand {
428 let packet = match TftpPacket::parse(buf) {
429 Ok(packet) => packet,
430 Err(err) => {
431 return ServerCommand::Send(TftpPacket::Error(TftpErrorPacket::new(
432 TftpErrorCode::Undefined,
433 format!("invalid packet: {err}"),
434 )));
435 }
436 };
437
438 match packet {
439 TftpPacket::ReadRequest(req) => self.process_read_req(fs, source, &req),
440 TftpPacket::Ack(ack) => self.process_ack(fs, source, &ack),
441 TftpPacket::Error(err) => {
285 println!( 442 println!(
286 "Request options: tsize={:?}, blksize={:?}", 443 "received error from client {}: ({}) {}",
287 req.tsize, req.blksize 444 source, err.code, err.message
288 ); 445 );
446 self.clients.remove(&source);
447 ServerCommand::Ignore
448 }
449 TftpPacket::Data(_) | TftpPacket::OAck(_) => ServerCommand::Ignore,
450 }
451 }
289 452
290 let filepath = dir.join(req.filename); 453 fn process_read_req(
291 current_file = filepath.clone(); 454 &mut self,
292 let meta = std::fs::metadata(&filepath).unwrap(); 455 fs: &dyn FileSystem,
293 let actual_file_size = meta.len(); 456 source: SocketAddr,
457 req: &TftpReadRequestPacket,
458 ) -> ServerCommand {
459 println!(
460 "Request options: tsize={:?}, blksize={:?}",
461 req.tsize, req.blksize
462 );
463
464 let client = self.clients.entry(source).or_default();
465 client.filename = Some(req.filename.clone());
466
467 // Only send OACK if client sent options
468 if req.tsize.is_some() || req.blksize.is_some() {
469 if let Some(blksize) = req.blksize {
470 client.blksize = blksize;
471 }
294 472
295 // Only send OACK if client sent options 473 let tsize_response = if req.tsize.is_some() {
296 if req.tsize.is_some() || req.blksize.is_some() { 474 let filesize = match fs.stat(&req.filename) {
297 if let Some(blksize) = req.blksize { 475 Ok(filesize) => filesize,
298 last_blksize = blksize; 476 Err(err) => {
477 return ServerCommand::error(
478 TftpErrorCode::Undefined,
479 format!("failed to obtain file size: {}", err),
480 );
299 } 481 }
482 };
300 483
301 let tsize_response = if req.tsize.is_some() { 484 Some(filesize)
302 Some(actual_file_size) 485 } else {
303 } else { 486 None
304 None 487 };
305 }; 488
306 489 ServerCommand::Send(TftpPacket::OAck(TftpOAckPacket {
307 Some(TftpPacket::OAck(TftpOAckPacket { 490 tsize: tsize_response,
308 tsize: tsize_response, 491 blksize: req.blksize,
309 blksize: req.blksize, 492 }))
310 })) 493 } else {
311 } else { 494 // No options, send first data block directly
312 // No options, send first data block directly 495 let options = self.clients.entry(source).or_default();
313 let contents = std::fs::read(&filepath).unwrap(); 496 let block_size = usize::try_from(options.blksize).unwrap();
314 let block_size = 512; 497
315 let first_block = if contents.len() > block_size { 498 assert!(block_size <= MAX_BLOCK_SIZE);
316 contents[..block_size].to_vec() 499 let mut contents = [0u8; MAX_BLOCK_SIZE];
317 } else { 500 let contents = &mut contents[..block_size];
318 contents 501
319 }; 502 let n = match fs.read(&req.filename, 0, contents) {
320 503 Ok(n) => usize::try_from(n).unwrap(),
321 Some(TftpPacket::Data(TftpDataPacket { 504 Err(err) => {
322 block: 1, 505 return ServerCommand::error(
323 data: first_block, 506 TftpErrorCode::Undefined,
324 })) 507 format!("failed to read file contents: {}", err),
325 } 508 );
326 }
327 TftpPacket::Data(dat) => unimplemented!(),
328 TftpPacket::Ack(ack) => {
329 println!("Received ACK packet: block {}", ack.block);
330
331 let contents = std::fs::read(&current_file).unwrap();
332 let next_block = ack.block + 1;
333 let start_offset = (next_block - 1) as u64 * last_blksize;
334 let end_offset = next_block as u64 * last_blksize;
335 let prev_start_offset = (next_block.saturating_sub(2)) as u64 * last_blksize;
336 let prev_remain = contents.len() - prev_start_offset as usize;
337 if prev_remain as u64 >= last_blksize || ack.block == 0 {
338 let end = std::cmp::min(end_offset as usize, contents.len());
339 let block_data = contents[start_offset as usize..end].to_vec();
340 println!("sending tftp data packet with {} bytes", block_data.len());
341 Some(TftpPacket::Data(TftpDataPacket {
342 block: next_block,
343 data: block_data,
344 }))
345 } else {
346 None
347 } 509 }
510 };
511
512 ServerCommand::Send(TftpPacket::Data(TftpDataPacket {
513 block: 1,
514 data: contents[..n].to_vec(),
515 }))
516 }
517 }
518
519 fn process_ack(
520 &mut self,
521 fs: &dyn FileSystem,
522 source: SocketAddr,
523 ack: &TftpAckPacket,
524 ) -> ServerCommand {
525 println!("Received ACK packet: block {}", ack.block);
526
527 let client = self.clients.entry(source).or_default();
528 let filename = match &client.filename {
529 Some(filename) => filename,
530 None => {
531 return ServerCommand::error(
532 TftpErrorCode::Undefined,
533 "unknown filename for client",
534 );
348 } 535 }
349 TftpPacket::OAck(ack) => todo!(), 536 };
350 TftpPacket::Error(err) => { 537 let block_size = usize::try_from(client.blksize).unwrap();
351 println!( 538
352 "Received ERROR packet: code {}, message: {}", 539 let next_block = ack.block + 1;
353 err.code, err.message 540 let start_offset = (usize::from(next_block) - 1) * block_size;
541
542 let mut contents = [0u8; MAX_BLOCK_SIZE];
543 let contents = &mut contents[..block_size];
544 let n = match fs.read(filename, u64::try_from(start_offset).unwrap(), contents) {
545 Ok(n) => usize::try_from(n).unwrap(),
546 Err(err) => {
547 return ServerCommand::error(
548 TftpErrorCode::Undefined,
549 format!("failed to read file contents: {}", err),
354 ); 550 );
355 None
356 } 551 }
357 }; 552 };
553 let contents = &contents[..n];
358 554
359 if let Some(response) = response { 555 if contents.is_empty() {
360 let mut writer = Cursor::new(&mut buf[..]); 556 return ServerCommand::Ignore;
361 println!("Sending to {addr}: {response:#?}");
362 response.write(&mut writer).unwrap();
363 let (response, _) = writer.split();
364 socket.send_to(&response, addr).unwrap();
365 } 557 }
558
559 ServerCommand::Send(TftpPacket::Data(TftpDataPacket {
560 block: next_block,
561 data: contents.to_vec(),
562 }))
366 } 563 }
367} 564}