Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Try encapsulating writes #5

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion zvt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ zvt_builder = { version = "0.1.0", path = "../zvt_builder" }
log = "0.4.19"
env_logger = "0.10.0"
tokio-stream = "0.1.14"
tokio = { version = "1.29.1", features = ["net", "io-util", "rt-multi-thread", "macros"] }
tokio = { version = "1.29.1", features = ["net", "io-util", "rt-multi-thread", "macros", "fs"] }
async-stream = "0.3.5"
serde = { version = "1.0.185", features = ["derive"] }
serde_json = "1.0.105"
futures = "0.3.28"
pin-project = "1.1.3"

65 changes: 64 additions & 1 deletion zvt/src/bin/status/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,76 @@ fn init_logger() {
.init();
}

use std::fs::File;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

pub struct LoggingTcpStream {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the whole idea was to support not only tcp but also other ways of communication just as the python code did. Hardcoding tcp here is a regression IMHO

stream: TcpStream,
file: File,
}

impl LoggingTcpStream {
pub async fn new(stream: TcpStream, file_path: &str) -> io::Result<Self> {
let file = File::create(file_path)?;
Ok(LoggingTcpStream { stream, file })
}
}

impl AsyncRead for LoggingTcpStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let self_mut = self.get_mut();
let read = buf.filled().len();
let poll = Pin::new(&mut self_mut.stream).poll_read(cx, buf);
if let Poll::Ready(Ok(_)) = poll {
let buf = &buf.filled()[read..];
self_mut.file.write_all(buf)?;
}
poll
}
}

impl AsyncWrite for LoggingTcpStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let self_mut = self.get_mut();
let poll = Pin::new(&mut self_mut.stream).poll_write(cx, buf);
if let Poll::Ready(Ok(size)) = poll {
// If write is successful, write to file
let data = &buf[..size];
self_mut.file.write_all(data)?;
}
poll
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().stream).poll_flush(cx)
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().stream).poll_shutdown(cx)
}
}

#[tokio::main]
async fn main() -> std::io::Result<()> {
init_logger();
let args = Args::parse();

info!("Using the args {:?}", args);
let mut socket = TcpStream::connect(args.ip).await?;
let mut socket = LoggingTcpStream {
stream: TcpStream::connect(args.ip).await?,
file: File::create("/tmp/dump.txt")?,
};

let request = packets::Registration {
password: args.password,
Expand Down
28 changes: 11 additions & 17 deletions zvt/src/feig/sequences.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use crate::sequences::{read_packet_async, write_with_ack_async, Sequence};
use crate::{packets, ZvtEnum, ZvtParser, ZvtSerializer};
use crate::sequences::DataSource;
use crate::sequences::Sequence;
use crate::{packets, ZvtEnum, ZvtParser};
use anyhow::Result;
use async_stream::try_stream;
use std::boxed::Box;
use std::collections::HashMap;
use std::io::Seek;
use std::io::{Error, ErrorKind};
use std::marker::Unpin;
use std::os::unix::fs::FileExt;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_stream::Stream;
use zvt_builder::ZVTError;

Expand Down Expand Up @@ -90,15 +89,12 @@ pub enum WriteFileResponse {
}

impl WriteFile {
pub fn into_stream<Source>(
pub fn into_stream(
path: PathBuf,
password: usize,
adpu_size: u32,
src: &mut Source,
) -> Pin<Box<impl Stream<Item = Result<WriteFileResponse>> + '_>>
where
Source: AsyncReadExt + AsyncWriteExt + Unpin,
{
src: &mut DataSource,
) -> Pin<Box<impl Stream<Item = Result<WriteFileResponse>> + '_>> {
// Protocol from the handbook (the numbering is not part of the handbook)
// 1.1 ECR->PT: Send over the list of all files with their sizes.
// 1.2 PT->ECR: Ack
Expand Down Expand Up @@ -133,27 +129,25 @@ impl WriteFile {
};

// 1.1. and 1.2
write_with_ack_async(&packet, &mut src).await?;
src.write_with_ack_async(&packet).await?;
let mut buf = vec![0; adpu_size as usize];
println!("the length is {}", buf.len());

loop {
// Get the data.
let bytes = read_packet_async(&mut src).await?;
let bytes = src.read_packet_async().await?;
println!("The packet is {:?}", bytes);

let response = WriteFileResponse::zvt_parse(&bytes)?;

match response {
WriteFileResponse::CompletionData(_) => {
src.write_all(&packets::Ack {}.zvt_serialize()).await?;

src.write_packet_async(&packets::Ack {}).await?;
yield response;
break;
}
WriteFileResponse::Abort(_) => {
src.write_all(&packets::Ack {}.zvt_serialize()).await?;

src.write_packet_async(&packets::Ack {}).await?;
yield response;
break;
}
Expand Down Expand Up @@ -195,7 +189,7 @@ impl WriteFile {
}),
}),
};
src.write_all(&packet.zvt_serialize()).await?;
src.write_packet_async(&packet).await?;

yield response;
}
Expand Down
73 changes: 60 additions & 13 deletions zvt/src/sequences.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use log::debug;
use std::boxed::Box;
use std::marker::Unpin;
use std::pin::Pin;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

pub async fn read_packet_async(src: &mut Pin<&mut impl AsyncReadExt>) -> Result<Vec<u8>> {
let mut buf = vec![0; 3];
Expand All @@ -31,6 +31,18 @@ pub async fn read_packet_async(src: &mut Pin<&mut impl AsyncReadExt>) -> Result<
Ok(buf.to_vec())
}

pub async fn write_packet_async<T>(
drain: &mut Pin<&mut impl AsyncWriteExt>,
p: &T,
) -> io::Result<()>
where
T: ZvtSerializer + Sync + Send,
encoding::Default: encoding::Encoding<T>,
{
let bytes = p.zvt_serialize();
drain.write_all(&bytes).await
}

#[derive(ZvtEnum)]
enum Ack {
Ack(packets::Ack),
Expand All @@ -46,14 +58,49 @@ where
{
// We declare the bytes as a separate variable to help the compiler to
// figure out that we can send stuff between threads.
let bytes = p.zvt_serialize();
src.write_all(&bytes).await?;
write_packet_async(src, p).await?;

let bytes = read_packet_async(src).await?;
let _ = Ack::zvt_parse(&bytes)?;
Ok(())
}

trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin + Sync + Send {}

#[pin_project::pin_project]
pub struct DataSource {
#[pin]
s: Box<dyn AsyncReadWrite>,
}

impl DataSource {
pub async fn read_packet_async(self: Pin<&mut Self>) -> Result<Vec<u8>> {
let this = self.project();
let mut pinned_inner = this.s;
Ok(read_packet_async(&mut pinned_inner).await?)
}

pub async fn write_packet_async<P>(self: Pin<&mut Self>, p: &P) -> Result<()>
where
P: ZvtSerializer + Sync + Send,
encoding::Default: encoding::Encoding<P>,
{
let this = self.project();
let mut pinned_inner = this.s;
Ok(write_packet_async(&mut pinned_inner, p).await?)
}

pub async fn write_with_ack_async<P>(self: Pin<&mut Self>, p: &P) -> Result<()>
where
P: ZvtSerializer + Sync + Send,
encoding::Default: encoding::Encoding<P>,
{
let this = self.project();
let mut pinned_inner = this.s;
Ok(write_with_ack_async(&mut pinned_inner, p).await?)
}
}

/// The trait for converting a sequence into a stream.
///
/// What is written below? The [Self::Input] type must be a command as defined
Expand Down Expand Up @@ -91,7 +138,7 @@ where
let bytes = read_packet_async(&mut src).await?;
let packet = Self::Output::zvt_parse(&bytes)?;
// Write the response.
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async::<packets::Ack>(&mut src, &packets::Ack {}).await?;
yield packet;
};
Box::pin(s)
Expand Down Expand Up @@ -149,7 +196,7 @@ impl Sequence for ReadCard {
let bytes = read_packet_async(&mut src).await?;
let packet = ReadCardResponse::zvt_parse(&bytes)?;
// Write the response.
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;

match packet {
ReadCardResponse::StatusInformation(_) | ReadCardResponse::Abort(_) => {
Expand Down Expand Up @@ -206,7 +253,7 @@ impl Sequence for Initialization {
let response = InitializationResponse::zvt_parse(&bytes)?;

// Every message requires an Ack.
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;

match response {
InitializationResponse::CompletionData(_)
Expand Down Expand Up @@ -313,7 +360,7 @@ impl Sequence for Diagnosis {
let response = DiagnosisResponse::zvt_parse(&bytes)?;

// Every message requires an Ack.
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;

match response {
DiagnosisResponse::CompletionData(_)
Expand Down Expand Up @@ -380,7 +427,7 @@ impl Sequence for EndOfDay {
let packet = EndOfDayResponse::zvt_parse(&bytes)?;

// Write the response.
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;
match packet {
EndOfDayResponse::CompletionData(_) | EndOfDayResponse::Abort(_) => {
yield packet;
Expand Down Expand Up @@ -445,7 +492,7 @@ impl Sequence for Reservation {
loop {
let bytes = read_packet_async(&mut src).await?;
let packet = AuthorizationResponse::zvt_parse(&bytes)?;
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;
match packet {
AuthorizationResponse::CompletionData(_) | AuthorizationResponse::Abort(_) => {
yield packet;
Expand Down Expand Up @@ -515,7 +562,7 @@ impl Sequence for PartialReversal {
loop {
let bytes = read_packet_async(&mut src).await?;
let packet = PartialReversalResponse::zvt_parse(&bytes)?;
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;
match packet {
PartialReversalResponse::CompletionData(_)
| PartialReversalResponse::PartialReversalAbort(_) => {
Expand Down Expand Up @@ -555,7 +602,7 @@ impl Sequence for PreAuthReversal {
loop {
let bytes = read_packet_async(&mut src).await?;
let packet = PartialReversalResponse::zvt_parse(&bytes)?;
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;
match packet {
PartialReversalResponse::CompletionData(_)
| PartialReversalResponse::PartialReversalAbort(_) => {
Expand Down Expand Up @@ -606,7 +653,7 @@ impl Sequence for PrintSystemConfiguration {
loop {
let bytes = read_packet_async(&mut src).await?;
let packet = PrintSystemConfigurationResponse::zvt_parse(&bytes)?;
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;
match packet {
PrintSystemConfigurationResponse::CompletionData(_) => {
yield packet;
Expand Down Expand Up @@ -677,7 +724,7 @@ impl Sequence for StatusEnquiry {
loop {
let bytes = read_packet_async(&mut src).await?;
let packet = StatusEnquiryResponse::zvt_parse(&bytes)?;
src.write_all(&packets::Ack {}.zvt_serialize()).await?;
write_packet_async(&mut src, &packets::Ack {}).await?;
match packet {
StatusEnquiryResponse::CompletionData(_) => {
yield packet;
Expand Down
Loading