diff --git a/Sources/LLM/LLM.swift b/Sources/LLM/LLM.swift
index 329cfc2..f5bb0ee 100644
--- a/Sources/LLM/LLM.swift
+++ b/Sources/LLM/LLM.swift
@@ -205,6 +205,10 @@ open class LLM: ObservableObject {
     private var currentCount: Int32!
     private var decoded = ""
     
+    // Overridable hook invoked when the input still exceeds `maxTokenCount`
+    // after all history has been evicted; default yields a placeholder response.
+    open func recoverFromLengthy(_ input: borrowing String, to output: borrowing AsyncStream<String>.Continuation) {
+        output.yield("tl;dr")
+    }
+    
     private func prepare(from input: borrowing String, to output: borrowing AsyncStream<String>.Continuation) -> Bool {
         guard !input.isEmpty else { return false }
         context = .init(model, params)
@@ -212,16 +216,17 @@
         var initialCount = tokens.count
         currentCount = Int32(initialCount)
         if maxTokenCount <= currentCount {
-            if history.isEmpty {
-                isFull = true
-                output.yield("Input is too long.")
-                return false
-            } else {
+            // Evict oldest exchanges only until the re-encoded prompt fits;
+            // stop early instead of always draining the entire history.
+            while !history.isEmpty, maxTokenCount <= currentCount {
                 history.removeFirst(min(2, history.count))
                 tokens = encode(preProcess(self.input, history))
                 initialCount = tokens.count
                 currentCount = Int32(initialCount)
             }
+            if maxTokenCount <= currentCount {
+                isFull = true
+                recoverFromLengthy(input, to: output)
+                return false
+            }
         }
         for (i, token) in tokens.enumerated() {
             batch.n_tokens = Int32(i)
diff --git a/Tests/LLMTests/LLMTests.swift b/Tests/LLMTests/LLMTests.swift
index 6599e0e..0ec93e4 100644
--- a/Tests/LLMTests/LLMTests.swift
+++ b/Tests/LLMTests/LLMTests.swift
@@ -212,4 +212,11 @@ final class LLMTests: XCTestCase {
         await bot.respond(to: input)
         #assert(!bot.output.isEmpty)
     }
+    
+    func testRecoveryFromLengthyInput() async throws {
+        let bot = try await LLM(from: model, maxTokenCount: 16)
+        let input = "have you heard of this so-called LLM.swift library?"
+        await bot.respond(to: input)
+        #assert(bot.output == "tl;dr")
+    }
 }