Skip to content

Commit

Permalink
add feature to observe download progress
Browse files Browse the repository at this point in the history
  • Loading branch information
eastriverlee committed Jan 30, 2024
1 parent 9929934 commit 7860485
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 16 deletions.
24 changes: 16 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,27 +61,35 @@ struct ContentView: View {

```swift
class Bot: LLM {
convenience init?() async {
convenience init?(_ update: @escaping (Double) -> Void) async {
let systemPrompt = "You are a sentient AI with emotions."
let model = HuggingFaceModel("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", .Q2_K, template: .chatML(systemPrompt))
try? await self.init(from: model)
try? await self.init(from: model) { progress in update(progress) }
}
}

...

struct ContentView: View {
@State var bot: Bot? = nil
@State var progress: CGFloat = 0
func updateProgress(_ progress: Double) {
Task { await MainActor.run { self.progress = CGFloat(progress) } }
}
var body: some View {
if let bot {
BotView(bot)
} else {
ProgressView().padding()
Text("(loading huggingface model...)").opacity(0.2)
.onAppear() { Task {
let bot = await Bot()
await MainActor.run { self.bot = bot }
} }
ProgressView(value: progress) {
Text("loading huggingface model...")
} currentValueLabel: {
Text(String(format: "%.2f%%", progress * 100))
}
.padding()
.onAppear() { Task {
let bot = await Bot(updateProgress)
await MainActor.run { self.bot = bot }
} }
}
}
}
Expand Down
35 changes: 29 additions & 6 deletions Sources/LLM/LLM.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ open class LLM: ObservableObject {
private var stopSequenceLength: Int
private var params: llama_context_params
private var isFull = false
private var updateProgress: (Double) -> Void = { _ in }

public init(
from path: String,
Expand Down Expand Up @@ -131,9 +132,12 @@ open class LLM: ObservableObject {
topP: Float = 0.95,
temp: Float = 0.8,
historyLimit: Int = 8,
maxTokenCount: Int32 = 2048
maxTokenCount: Int32 = 2048,
updateProgress: @escaping (Double) -> Void = { print(String(format: "downloaded(%.2f%%)", $0 * 100)) }
) async throws {
let url = try await huggingFaceModel.download(to: url, as: name)
let url = try await huggingFaceModel.download(to: url, as: name) { progress in
Task { await updateProgress(progress) }
}
self.init(
from: url,
template: huggingFaceModel.template,
Expand All @@ -145,6 +149,7 @@ open class LLM: ObservableObject {
historyLimit: historyLimit,
maxTokenCount: maxTokenCount
)
self.updateProgress = updateProgress
}

public convenience init(
Expand Down Expand Up @@ -578,6 +583,7 @@ public enum Quantization: String {
public enum HuggingFaceError: Error {
case network(statusCode: Int)
case noFilteredURL
case urlIsNilForSomeReason
}

public struct HuggingFaceModel {
Expand Down Expand Up @@ -616,17 +622,16 @@ public struct HuggingFaceModel {
return nil
}

public func download(to directory: URL = .documentsDirectory, as name: String? = nil) async throws -> URL {
public func download(to directory: URL = .documentsDirectory, as name: String? = nil, _ updateProgress: @escaping (Double) -> Void) async throws -> URL {
var destination: URL
if let name {
destination = directory.appending(path: name)
guard !destination.exists else { return destination }
guard !destination.exists else { updateProgress(1); return destination }
}
guard let downloadURL = try await getDownloadURL() else { throw HuggingFaceError.noFilteredURL }
destination = directory.appending(path: downloadURL.lastPathComponent)
guard !destination.exists else { return destination }
let data = try await downloadURL.getData()
try data.write(to: destination)
try await downloadURL.downloadData(to: destination, updateProgress)
return destination
}

Expand All @@ -651,6 +656,24 @@ extension URL {
guard statusCode / 100 == 2 else { throw HuggingFaceError.network(statusCode: statusCode) }
return data
}
fileprivate func downloadData(to destination: URL, _ updateProgress: @escaping (Double) -> Void) async throws {
var observation: NSKeyValueObservation!
let url: URL = try await withCheckedThrowingContinuation { continuation in
let task = URLSession.shared.downloadTask(with: self) { url, response, error in
if let error { return continuation.resume(throwing: error) }
guard let url else { return continuation.resume(throwing: HuggingFaceError.urlIsNilForSomeReason) }
let statusCode = (response as! HTTPURLResponse).statusCode
guard statusCode / 100 == 2 else { return continuation.resume(throwing: HuggingFaceError.network(statusCode: statusCode)) }
continuation.resume(returning: url)
}
observation = task.progress.observe(\.fractionCompleted) { progress, _ in
updateProgress(progress.fractionCompleted)
}
task.resume()
}
let _ = observation
try FileManager.default.moveItem(at: url, to: destination)
}
}

package extension String {
Expand Down
4 changes: 2 additions & 2 deletions Tests/LLMTests/LLMTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,14 @@ final class LLMTests: XCTestCase {
}

func testInferenceFromHuggingFaceModel() async throws {
var bot = try await LLM(from: model)
let bot = try await LLM(from: model)
let input = "have you heard of this so-called LLM.swift library?"
await bot.respond(to: input)
#assert(!bot.output.isEmpty)
}

func testRecoveryFromLengtyInput() async throws {
var bot = try await LLM(from: model, maxTokenCount: 16)
let bot = try await LLM(from: model, maxTokenCount: 16)
let input = "have you heard of this so-called LLM.swift library?"
await bot.respond(to: input)
#assert(bot.output == "tl;dr")
Expand Down

0 comments on commit 7860485

Please sign in to comment.