Skip to content

Commit

Permalink
added weight key transform, added more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jkrukowski committed Jan 22, 2025
1 parent 47697a0 commit 57ecf70
Show file tree
Hide file tree
Showing 12 changed files with 467 additions and 94 deletions.
60 changes: 57 additions & 3 deletions Package.resolved
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
{
"originHash" : "3173defd78a48faa60b1c56cfa74f15c0c2b63eee978ea01ea5eb21e0b8e5939",
"originHash" : "5019d3e17ad01fb255fb8b5a5956509e8a5e887ebb8a4e735d61ca4c1d79b894",
"pins" : [
{
"identity" : "command",
"kind" : "remoteSourceControl",
"location" : "https://github.com/tuist/Command.git",
"state" : {
"revision" : "6da5edd8893552d45fdafa8545ff7867d17986b4",
"version" : "0.11.16"
}
},
{
"identity" : "jinja",
"kind" : "remoteSourceControl",
Expand All @@ -10,6 +19,24 @@
"version" : "1.0.6"
}
},
{
"identity" : "mockable",
"kind" : "remoteSourceControl",
"location" : "https://github.com/Kolos65/Mockable",
"state" : {
"revision" : "e1b311b01c11415099341eee49769185e965ac4c",
"version" : "0.2.0"
}
},
{
"identity" : "path",
"kind" : "remoteSourceControl",
"location" : "https://github.com/tuist/Path",
"state" : {
"revision" : "7c74ac435e03a927c3a73134c48b61e60221abcb",
"version" : "0.3.8"
}
},
{
"identity" : "swift-argument-parser",
"kind" : "remoteSourceControl",
Expand All @@ -19,6 +46,15 @@
"version" : "1.5.0"
}
},
{
"identity" : "swift-log",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-log",
"state" : {
"revision" : "96a2f8a0fa41e9e09af4585e2724c4e825410b91",
"version" : "1.6.2"
}
},
{
"identity" : "swift-numerics",
"kind" : "remoteSourceControl",
Expand All @@ -33,8 +69,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/jkrukowski/swift-safetensors",
"state" : {
"revision" : "f730309020d5b53e137b30bc10eb1d168954e1e7",
"version" : "0.0.6"
"revision" : "718b0f38f912e0bf9d92130fa1e1fe2ae5136dd6",
"version" : "0.0.7"
}
},
{
Expand All @@ -46,6 +82,15 @@
"version" : "0.0.5"
}
},
{
"identity" : "swift-syntax",
"kind" : "remoteSourceControl",
"location" : "https://github.com/swiftlang/swift-syntax",
"state" : {
"revision" : "0687f71944021d616d34d922343dcef086855920",
"version" : "600.0.1"
}
},
{
"identity" : "swift-transformers",
"kind" : "remoteSourceControl",
Expand All @@ -54,6 +99,15 @@
"revision" : "d42fdae473c49ea216671da8caae58e102d28709",
"version" : "0.1.14"
}
},
{
"identity" : "xctest-dynamic-overlay",
"kind" : "remoteSourceControl",
"location" : "https://github.com/pointfreeco/xctest-dynamic-overlay",
"state" : {
"revision" : "a3f634d1a409c7979cabc0a71b3f26ffa9fc8af1",
"version" : "1.4.3"
}
}
],
"version" : 3
Expand Down
17 changes: 16 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ let package = Package(
),
.package(
url: "https://github.com/jkrukowski/swift-safetensors.git",
from: "0.0.6"
from: "0.0.7"
),
.package(
url: "https://github.com/apple/swift-argument-parser.git",
Expand All @@ -44,6 +44,10 @@ let package = Package(
url: "https://github.com/jkrukowski/swift-sentencepiece",
from: "0.0.5"
),
.package(
url: "https://github.com/tuist/Command.git",
from: "0.11.16"
),
],
targets: [
.executableTarget(
Expand Down Expand Up @@ -84,6 +88,17 @@ let package = Package(
.copy("Resources")
]
),
.testTarget(
name: "AccuracyTests",
dependencies: [
"Embeddings",
"TestingUtils",
.product(name: "Command", package: "Command"),
],
resources: [
.copy("Scripts")
]
),
.testTarget(
name: "MLTensorUtilsTests",
dependencies: [
Expand Down
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ Some of the supported models on `Hugging Face`:
- [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
- [sentence-transformers/msmarco-bert-base-dot-v5](https://huggingface.co/sentence-transformers/msmarco-bert-base-dot-v5)
- [thenlper/gte-base](https://huggingface.co/thenlper/gte-base)
- [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased)

NOTE: `google-bert/bert-base-uncased` is supported, but a `weightKeyTransform` must be provided when loading it:

```swift
let modelBundle = try await Bert.loadModelBundle(from: modelId, weightKeyTransform: Bert.googleWeightsKeyTransform)
```

### XLM-RoBERTa (Cross-lingual Language Model - Robustly Optimized BERT Approach)

Expand Down
91 changes: 63 additions & 28 deletions Sources/Embeddings/Bert/BertUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,56 +15,78 @@ extension Bert {
public static func loadModelBundle(
from hubRepoId: String,
downloadBase: URL? = nil,
useBackgroundSession: Bool = false
useBackgroundSession: Bool = false,
weightKeyTransform: ((String) -> String) = { $0 }
) async throws -> Bert.ModelBundle {
let modelFolder = try await downloadModelFromHub(
from: hubRepoId,
downloadBase: downloadBase,
useBackgroundSession: useBackgroundSession
)
return try await loadModelBundle(from: modelFolder)
return try await loadModelBundle(
from: modelFolder,
weightKeyTransform: weightKeyTransform
)
}

public static func loadModelBundle(from modelFolder: URL) async throws -> Bert.ModelBundle {
public static func loadModelBundle(
from modelFolder: URL,
weightKeyTransform: ((String) -> String) = { $0 }
) async throws -> Bert.ModelBundle {
let tokenizer = try await AutoTokenizer.from(modelFolder: modelFolder)
// NOTE: just `safetensors` support for now
let weightsUrl = modelFolder.appendingPathComponent("model.safetensors")
let configUrl = modelFolder.appendingPathComponent("config.json")
let config = try Bert.loadConfig(at: configUrl)
let model = try Bert.loadModel(weightsUrl: weightsUrl, config: config)
let model = try Bert.loadModel(
weightsUrl: weightsUrl,
config: config,
weightKeyTransform: weightKeyTransform
)
return Bert.ModelBundle(model: model, tokenizer: TokenizerWrapper(tokenizer))
}
}

extension Bert {
    /// Key transformation required for the Google BERT weights.
    ///
    /// Prepends `bert.` to `key`, then applies two suffix rewrites in order:
    /// `.LayerNorm.weight` becomes `.LayerNorm.gamma` and `.LayerNorm.bias`
    /// becomes `.LayerNorm.beta`.
    ///
    /// Model available here:
    /// [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased)
    ///
    /// - Parameter key: The weight key to transform.
    /// - Returns: The prefixed key with the LayerNorm suffixes rewritten.
    public static func googleWeightsKeyTransform(_ key: String) -> String {
        let prefixedKey = "bert.\(key)"
        // NOTE(review): `replace(suffix:with:)` is declared elsewhere in the project;
        // presumably it only rewrites a trailing match and otherwise returns the string unchanged.
        let weightRenamed = prefixedKey.replace(
            suffix: ".LayerNorm.weight",
            with: ".LayerNorm.gamma"
        )
        let biasRenamed = weightRenamed.replace(
            suffix: ".LayerNorm.bias",
            with: ".LayerNorm.beta"
        )
        return biasRenamed
    }
}

extension Bert {
public static func loadModel(
weightsUrl: URL,
config: Bert.ModelConfig
config: Bert.ModelConfig,
weightKeyTransform: ((String) -> String) = { $0 }
) throws -> Bert.Model {
// NOTE: just `safetensors` support for now
let safetensors = try Safetensors.read(at: weightsUrl)
let pooler = try Bert.Pooler(
dense: MLTensorUtils.linear(
weight: safetensors.mlTensor(forKey: "pooler.dense.weight"),
bias: safetensors.mlTensor(forKey: "pooler.dense.bias")))
weight: safetensors.mlTensor(forKey: weightKeyTransform("pooler.dense.weight")),
bias: safetensors.mlTensor(forKey: weightKeyTransform("pooler.dense.bias"))))

let wordEmbeddings = try MLTensorUtils.embedding(
weight: safetensors.mlTensor(
forKey: "embeddings.word_embeddings.weight"))
forKey: weightKeyTransform("embeddings.word_embeddings.weight")))

let tokenTypeEmbeddings = try MLTensorUtils.embedding(
weight: safetensors.mlTensor(
forKey: "embeddings.token_type_embeddings.weight"))
forKey: weightKeyTransform("embeddings.token_type_embeddings.weight")))

let positionEmbeddings = try MLTensorUtils.embedding(
weight: safetensors.mlTensor(
forKey: "embeddings.position_embeddings.weight"))
forKey: weightKeyTransform("embeddings.position_embeddings.weight")))

let layerNorm = try MLTensorUtils.layerNorm(
weight: safetensors.mlTensor(
forKey: "embeddings.LayerNorm.weight"),
forKey: weightKeyTransform("embeddings.LayerNorm.weight")),
bias: safetensors.mlTensor(
forKey: "embeddings.LayerNorm.bias"),
forKey: weightKeyTransform("embeddings.LayerNorm.bias")),
epsilon: config.layerNormEps)

let embeddings = Bert.Embeddings(
Expand All @@ -78,19 +100,25 @@ extension Bert {
let bertSelfAttention = try Bert.SelfAttention(
query: MLTensorUtils.linear(
weight: safetensors.mlTensor(
forKey: "encoder.layer.\(layer).attention.self.query.weight"),
forKey: weightKeyTransform(
"encoder.layer.\(layer).attention.self.query.weight")),
bias: safetensors.mlTensor(
forKey: "encoder.layer.\(layer).attention.self.query.bias")),
forKey: weightKeyTransform(
"encoder.layer.\(layer).attention.self.query.bias"))),
key: MLTensorUtils.linear(
weight: safetensors.mlTensor(
forKey: "encoder.layer.\(layer).attention.self.key.weight"),
forKey: weightKeyTransform(
"encoder.layer.\(layer).attention.self.key.weight")),
bias: safetensors.mlTensor(
forKey: "encoder.layer.\(layer).attention.self.key.bias")),
forKey: weightKeyTransform("encoder.layer.\(layer).attention.self.key.bias")
)),
value: MLTensorUtils.linear(
weight: safetensors.mlTensor(
forKey: "encoder.layer.\(layer).attention.self.value.weight"),
forKey: weightKeyTransform(
"encoder.layer.\(layer).attention.self.value.weight")),
bias: safetensors.mlTensor(
forKey: "encoder.layer.\(layer).attention.self.value.bias")),
forKey: weightKeyTransform(
"encoder.layer.\(layer).attention.self.value.bias"))),
numAttentionHeads: config.numAttentionHeads,
attentionHeadSize: config.hiddenSize / config.numAttentionHeads,
allHeadSize: config.numAttentionHeads
Expand All @@ -99,14 +127,18 @@ extension Bert {
let bertSelfOutput = try Bert.SelfOutput(
dense: MLTensorUtils.linear(
weight: safetensors.mlTensor(
forKey: "encoder.layer.\(layer).attention.output.dense.weight"),
forKey: weightKeyTransform(
"encoder.layer.\(layer).attention.output.dense.weight")),
bias: safetensors.mlTensor(
forKey: "encoder.layer.\(layer).attention.output.dense.bias")),
forKey: weightKeyTransform(
"encoder.layer.\(layer).attention.output.dense.bias"))),
layerNorm: MLTensorUtils.layerNorm(
weight: safetensors.mlTensor(
forKey: "encoder.layer.\(layer).attention.output.LayerNorm.weight"),
forKey: weightKeyTransform(
"encoder.layer.\(layer).attention.output.LayerNorm.weight")),
bias: safetensors.mlTensor(
forKey: "encoder.layer.\(layer).attention.output.LayerNorm.bias"),
forKey: weightKeyTransform(
"encoder.layer.\(layer).attention.output.LayerNorm.bias")),
epsilon: config.layerNormEps)
)
let bertAttention = Bert.Attention(
Expand All @@ -116,21 +148,24 @@ extension Bert {
let bertIntermediate = try Bert.Intermediate(
dense: MLTensorUtils.linear(
weight: safetensors.mlTensor(
forKey: "encoder.layer.\(layer).intermediate.dense.weight"),
forKey: weightKeyTransform(
"encoder.layer.\(layer).intermediate.dense.weight")),
bias: safetensors.mlTensor(
forKey: "encoder.layer.\(layer).intermediate.dense.bias"))
forKey: weightKeyTransform("encoder.layer.\(layer).intermediate.dense.bias")
))
)
let bertOutput = try Bert.Output(
dense: MLTensorUtils.linear(
weight: safetensors.mlTensor(
forKey: "encoder.layer.\(layer).output.dense.weight"),
forKey: weightKeyTransform("encoder.layer.\(layer).output.dense.weight")),
bias: safetensors.mlTensor(
forKey: "encoder.layer.\(layer).output.dense.bias")),
forKey: weightKeyTransform("encoder.layer.\(layer).output.dense.bias"))),
layerNorm: MLTensorUtils.layerNorm(
weight: safetensors.mlTensor(
forKey: "encoder.layer.\(layer).output.LayerNorm.weight"),
forKey: weightKeyTransform("encoder.layer.\(layer).output.LayerNorm.weight")
),
bias: safetensors.mlTensor(
forKey: "encoder.layer.\(layer).output.LayerNorm.bias"),
forKey: weightKeyTransform("encoder.layer.\(layer).output.LayerNorm.bias")),
epsilon: config.layerNormEps))

let bertLayer = Bert.Layer(
Expand Down
Loading

0 comments on commit 57ecf70

Please sign in to comment.