diff --git a/src/openai/pipelines/llm_engine.rs b/src/openai/pipelines/llm_engine.rs index 4cffbdd..0ba766c 100644 --- a/src/openai/pipelines/llm_engine.rs +++ b/src/openai/pipelines/llm_engine.rs @@ -370,7 +370,12 @@ impl<'a> LLMEngine<'a> { } let block_number = if i / self.cache_config.block_size >= table.len() { - table.get(table.len() - 1).unwrap() //position exceed! use last position + panic!( + "Block table is too small (prompt)! i={} block_size={} table_len={}", + i, + self.cache_config.block_size, + table.len() + ); } else { table.get(i / self.cache_config.block_size).unwrap() }; @@ -460,7 +465,7 @@ impl<'a> LLMEngine<'a> { .collect::>(); let block_number = if position / self.cache_config.block_size >= table.len() { - table.get(table.len() - 1).unwrap() //position exceed! use last position; TODO (bug fix) + panic!("Block table is too small (completion)! start_pos={} block_size={} table_len={}", position, self.cache_config.block_size, table.len()); } else { table.get(position / self.cache_config.block_size).unwrap() }; diff --git a/src/scheduler/block_engine.rs b/src/scheduler/block_engine.rs index 1db9507..a9ac7a3 100644 --- a/src/scheduler/block_engine.rs +++ b/src/scheduler/block_engine.rs @@ -27,6 +27,10 @@ impl LogicalTokenBlock { self.num_tokens == self.block_size } + pub fn is_empty(&self) -> bool { + self.num_tokens == 0 + } + pub fn append_token_id(&mut self, token: usize) { assert!(!self.is_full()); self.tokens[self.num_tokens] = token; diff --git a/src/scheduler/sequence.rs b/src/scheduler/sequence.rs index 6e2db88..b1fc0db 100644 --- a/src/scheduler/sequence.rs +++ b/src/scheduler/sequence.rs @@ -76,7 +76,7 @@ impl _Sequence { pub fn blocks_to_add_new_tok(&self) -> usize { let last = self.logical_token_blocks.last(); - if !last.is_some_and(|last| last.is_full()) { + if !last.is_some_and(|last| last.is_full() || last.is_empty()) { // If we have space 0 } else { @@ -168,17 +168,22 @@ impl _Sequence { fn append_token_to_blocks(&mut self, token: usize) { let last = self.logical_token_blocks.last_mut(); - if last.is_some() && !last.as_ref().is_some_and(|last| last.is_full()) { - // If we have space - let last = last.unwrap(); - last.append_token_id(token); - } else { - self.logical_token_blocks - .push(LogicalTokenBlock::new(self.block_size)); + match last { + Some(last) => { + last.append_token_id(tok); + } + None => { + self.logical_token_blocks + .push(LogicalTokenBlock::new(self.block_size)); + self.logical_token_blocks + .last_mut() + .unwrap() + .append_token_id(tok); + } + } + if self.logical_token_blocks.last().as_ref().unwrap().is_full() { self.logical_token_blocks - .last_mut() - .unwrap() - .append_token_id(token); + .push(LogicalTokenBlock::new(*block_size)); } } }