Skip to content

Commit

Permalink
FIX: Fix use of symbol_cstr_len in MetadataEncoder
Browse files Browse the repository at this point in the history
  • Loading branch information
threecgreen committed Jan 8, 2024
1 parent 3e9b4c5 commit aede39b
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 66 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
`SymbolMappingMsgV1`
- Fixed cases where `dbn` CLI tool would write a broken pipe error to standard error
such as when piping to `head`
- Fixed bug in sync and async `MetadataEncoder`s where the `version` field, rather than
the `symbol_cstr_len` field, was used to determine the encoded length of fixed-length symbols

## 0.14.2 - 2023-11-17
### Enhancements
Expand Down
61 changes: 27 additions & 34 deletions rust/dbn/src/encode/dbn/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ where
.write_u32_le(length)
.await
.map_err(metadata_err)?;
self.encode_fixed_len_cstr::<{ crate::METADATA_DATASET_CSTR_LEN }>(&metadata.dataset)
self.encode_fixed_len_cstr(crate::METADATA_DATASET_CSTR_LEN, &metadata.dataset)
.await?;
self.writer
.write_u16_le(metadata.schema.map(|s| s as u16).unwrap_or(NULL_SCHEMA))
Expand Down Expand Up @@ -279,13 +279,13 @@ where
.map_err(metadata_err)?;
// schema_definition_length
self.writer.write_u32_le(0).await.map_err(metadata_err)?;
self.encode_repeated_symbol_cstr(metadata.version, &metadata.symbols)
self.encode_repeated_symbol_cstr(metadata.symbol_cstr_len, &metadata.symbols)
.await?;
self.encode_repeated_symbol_cstr(metadata.version, &metadata.partial)
self.encode_repeated_symbol_cstr(metadata.symbol_cstr_len, &metadata.partial)
.await?;
self.encode_repeated_symbol_cstr(metadata.version, &metadata.not_found)
self.encode_repeated_symbol_cstr(metadata.symbol_cstr_len, &metadata.not_found)
.await?;
self.encode_symbol_mappings(metadata.version, &metadata.mappings)
self.encode_symbol_mappings(metadata.symbol_cstr_len, &metadata.mappings)
.await?;

Ok(())
Expand Down Expand Up @@ -331,53 +331,45 @@ where
Ok(())
}

async fn encode_repeated_symbol_cstr(&mut self, version: u8, symbols: &[String]) -> Result<()> {
async fn encode_repeated_symbol_cstr(
&mut self,
symbol_cstr_len: usize,
symbols: &[String],
) -> Result<()> {
self.writer
.write_u32_le(symbols.len() as u32)
.await
.map_err(|e| Error::io(e, "writing cstr length"))?;
.map_err(|e| Error::io(e, "writing repeated symbols length"))?;
for symbol in symbols {
if version == 1 {
self.encode_fixed_len_cstr::<{ crate::compat::SYMBOL_CSTR_LEN_V1 }>(symbol)
.await
} else {
self.encode_fixed_len_cstr::<{ crate::SYMBOL_CSTR_LEN }>(symbol)
.await
}?;
self.encode_fixed_len_cstr(symbol_cstr_len, symbol).await?;
}

Ok(())
}

async fn encode_symbol_mappings(
&mut self,
version: u8,
symbol_cstr_len: usize,
symbol_mappings: &[SymbolMapping],
) -> Result<()> {
// encode mappings_count
self.writer
.write_u32_le(symbol_mappings.len() as u32)
.await
.map_err(|e| Error::io(e, "writing symbol mappings length"))?;
if version == 1 {
for symbol_mapping in symbol_mappings {
self.encode_symbol_mapping::<{ crate::compat::SYMBOL_CSTR_LEN_V1 }>(symbol_mapping)
.await?;
}
} else {
for symbol_mapping in symbol_mappings {
self.encode_symbol_mapping::<{ crate::SYMBOL_CSTR_LEN }>(symbol_mapping)
.await?;
}
for symbol_mapping in symbol_mappings {
self.encode_symbol_mapping(symbol_cstr_len, symbol_mapping)
.await?;
}
Ok(())
}

async fn encode_symbol_mapping<const LEN: usize>(
async fn encode_symbol_mapping(
&mut self,
symbol_cstr_len: usize,
symbol_mapping: &SymbolMapping,
) -> Result<()> {
self.encode_fixed_len_cstr::<LEN>(&symbol_mapping.raw_symbol)
self.encode_fixed_len_cstr(symbol_cstr_len, &symbol_mapping.raw_symbol)
.await?;
// encode interval_count
self.writer
Expand All @@ -391,22 +383,23 @@ where
self.encode_date(interval.end_date)
.await
.map_err(|e| Error::io(e, "writing end date"))?;
self.encode_fixed_len_cstr::<LEN>(&interval.symbol).await?;
self.encode_fixed_len_cstr(symbol_cstr_len, &interval.symbol)
.await?;
}
Ok(())
}

async fn encode_fixed_len_cstr<const LEN: usize>(&mut self, string: &str) -> Result<()> {
async fn encode_fixed_len_cstr(&mut self, symbol_cstr_len: usize, string: &str) -> Result<()> {
if !string.is_ascii() {
return Err(Error::Conversion {
input: string.to_owned(),
desired_type: "ASCII",
});
}
if string.len() > LEN {
if string.len() > symbol_cstr_len {
return Err(Error::encode(
format!(
"'{string}' is too long to be encoded in DBN; it cannot be longer than {LEN} characters"
"'{string}' is too long to be encoded in DBN; it cannot be longer than {symbol_cstr_len} characters"
)));
}
let cstr_err = |e| Error::io(e, "writing cstr");
Expand All @@ -415,7 +408,7 @@ where
.await
.map_err(cstr_err)?;
// pad remaining space with null bytes
for _ in string.len()..LEN {
for _ in string.len()..symbol_cstr_len {
self.writer.write_u8(0).await.map_err(cstr_err)?;
}
Ok(())
Expand Down Expand Up @@ -533,7 +526,7 @@ mod tests {
"LNQ".to_owned(),
];
target
.encode_repeated_symbol_cstr(crate::DBN_VERSION, symbols.as_slice())
.encode_repeated_symbol_cstr(crate::SYMBOL_CSTR_LEN, symbols.as_slice())
.await
.unwrap();
assert_eq!(
Expand All @@ -555,7 +548,7 @@ mod tests {
let mut buffer = Vec::new();
let mut target = MetadataEncoder::new(&mut buffer);
target
.encode_fixed_len_cstr::<{ crate::SYMBOL_CSTR_LEN }>("NG")
.encode_fixed_len_cstr(crate::SYMBOL_CSTR_LEN, "NG")
.await
.unwrap();
assert_eq!(buffer.len(), crate::SYMBOL_CSTR_LEN);
Expand Down
57 changes: 25 additions & 32 deletions rust/dbn/src/encode/dbn/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ where
self.writer
.write_all(length.to_le_bytes().as_slice())
.map_err(metadata_err)?;
self.encode_fixed_len_cstr::<{ crate::METADATA_DATASET_CSTR_LEN }>(&metadata.dataset)?;
self.encode_fixed_len_cstr(crate::METADATA_DATASET_CSTR_LEN, &metadata.dataset)?;
self.writer
.write_all(
(metadata.schema.map(|s| s as u16).unwrap_or(NULL_SCHEMA))
Expand Down Expand Up @@ -177,10 +177,10 @@ where
.write_all(0u32.to_le_bytes().as_slice())
.map_err(metadata_err)?;

self.encode_repeated_symbol_cstr(metadata.version, metadata.symbols.as_slice())?;
self.encode_repeated_symbol_cstr(metadata.version, metadata.partial.as_slice())?;
self.encode_repeated_symbol_cstr(metadata.version, metadata.not_found.as_slice())?;
self.encode_symbol_mappings(metadata.version, metadata.mappings.as_slice())?;
self.encode_repeated_symbol_cstr(metadata.symbol_cstr_len, metadata.symbols.as_slice())?;
self.encode_repeated_symbol_cstr(metadata.symbol_cstr_len, metadata.partial.as_slice())?;
self.encode_repeated_symbol_cstr(metadata.symbol_cstr_len, metadata.not_found.as_slice())?;
self.encode_symbol_mappings(metadata.symbol_cstr_len, metadata.mappings.as_slice())?;

Ok(())
}
Expand Down Expand Up @@ -244,49 +244,42 @@ where
Ok(())
}

fn encode_repeated_symbol_cstr(&mut self, version: u8, symbols: &[String]) -> Result<()> {
fn encode_repeated_symbol_cstr(
&mut self,
symbol_cstr_len: usize,
symbols: &[String],
) -> Result<()> {
self.writer
.write_all((symbols.len() as u32).to_le_bytes().as_slice())
.map_err(|e| Error::io(e, "writing cstr length"))?;
.map_err(|e| Error::io(e, "writing repeated symbols length"))?;
for symbol in symbols {
if version == 1 {
self.encode_fixed_len_cstr::<{ crate::compat::SYMBOL_CSTR_LEN_V1 }>(symbol)?;
} else {
self.encode_fixed_len_cstr::<{ crate::SYMBOL_CSTR_LEN }>(symbol)?;
}
self.encode_fixed_len_cstr(symbol_cstr_len, symbol)?;
}

Ok(())
}

fn encode_symbol_mappings(
&mut self,
version: u8,
symbol_cstr_len: usize,
symbol_mappings: &[SymbolMapping],
) -> Result<()> {
// encode mappings_count
self.writer
.write_all((symbol_mappings.len() as u32).to_le_bytes().as_slice())
.map_err(|e| Error::io(e, "writing symbol mappings length"))?;
if version == 1 {
for symbol_mapping in symbol_mappings {
self.encode_symbol_mapping::<{ crate::compat::SYMBOL_CSTR_LEN_V1 }>(
symbol_mapping,
)?;
}
} else {
for symbol_mapping in symbol_mappings {
self.encode_symbol_mapping::<{ crate::SYMBOL_CSTR_LEN }>(symbol_mapping)?;
}
for symbol_mapping in symbol_mappings {
self.encode_symbol_mapping(symbol_cstr_len, symbol_mapping)?;
}
Ok(())
}

fn encode_symbol_mapping<const LEN: usize>(
fn encode_symbol_mapping(
&mut self,
symbol_cstr_len: usize,
symbol_mapping: &SymbolMapping,
) -> Result<()> {
self.encode_fixed_len_cstr::<LEN>(&symbol_mapping.raw_symbol)?;
self.encode_fixed_len_cstr(symbol_cstr_len, &symbol_mapping.raw_symbol)?;
// encode interval_count
self.writer
.write_all(
Expand All @@ -300,28 +293,28 @@ where
.map_err(|e| Error::io(e, "writing start date"))?;
self.encode_date(interval.end_date)
.map_err(|e| Error::io(e, "writing end date"))?;
self.encode_fixed_len_cstr::<LEN>(&interval.symbol)?;
self.encode_fixed_len_cstr(symbol_cstr_len, &interval.symbol)?;
}
Ok(())
}

fn encode_fixed_len_cstr<const LEN: usize>(&mut self, string: &str) -> Result<()> {
fn encode_fixed_len_cstr(&mut self, symbol_cstr_len: usize, string: &str) -> Result<()> {
if !string.is_ascii() {
return Err(Error::Conversion {
input: string.to_owned(),
desired_type: "ASCII",
});
}
if string.len() > LEN {
if string.len() > symbol_cstr_len {
return Err(Error::encode(
format!(
"'{string}' is too long to be encoded in DBN; it cannot be longer than {LEN} characters"
"'{string}' is too long to be encoded in DBN; it cannot be longer than {symbol_cstr_len} characters"
)));
}
let cstr_err = |e| Error::io(e, "writing cstr");
self.writer.write_all(string.as_bytes()).map_err(cstr_err)?;
// pad remaining space with null bytes
for _ in string.len()..LEN {
for _ in string.len()..symbol_cstr_len {
self.writer.write_all(&[0]).map_err(cstr_err)?;
}
Ok(())
Expand Down Expand Up @@ -529,7 +522,7 @@ mod tests {
"LNQ".to_owned(),
];
target
.encode_repeated_symbol_cstr(crate::DBN_VERSION, symbols.as_slice())
.encode_repeated_symbol_cstr(crate::SYMBOL_CSTR_LEN, symbols.as_slice())
.unwrap();
assert_eq!(
buffer.len(),
Expand All @@ -550,7 +543,7 @@ mod tests {
let mut buffer = Vec::new();
let mut target = MetadataEncoder::new(&mut buffer);
target
.encode_fixed_len_cstr::<{ crate::SYMBOL_CSTR_LEN }>("NG")
.encode_fixed_len_cstr(crate::SYMBOL_CSTR_LEN, "NG")
.unwrap();
assert_eq!(buffer.len(), crate::SYMBOL_CSTR_LEN);
assert_eq!(&buffer[..2], b"NG");
Expand Down

0 comments on commit aede39b

Please sign in to comment.