Skip to content

Commit

Permalink
Skeleton ShieldGemma class
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanMullins committed Nov 5, 2024
1 parent 5a7ecb6 commit 4c4571f
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 3 deletions.
6 changes: 3 additions & 3 deletions keras_hub/src/models/gemma/gemma_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@
"metadata": {
"description": "2 billion parameter, 26-layer, ShieldGemma model.",
"params": 2614341888,
"official_name": "Gemma",
"official_name": "ShieldGemma",
"path": "gemma",
"model_card": "https://www.kaggle.com/models/google/shieldgemma",
},
Expand All @@ -216,7 +216,7 @@
"metadata": {
"description": "9 billion parameter, 42-layer, ShieldGemma model.",
"params": 9241705984,
"official_name": "Gemma",
"official_name": "ShieldGemma",
"path": "gemma",
"model_card": "https://www.kaggle.com/models/google/shieldgemma",
},
Expand All @@ -226,7 +226,7 @@
"metadata": {
"description": "27 billion parameter, 42-layer, ShieldGemma model.",
"params": 27227128320,
"official_name": "Gemma",
"official_name": "ShieldGemma",
"path": "gemma",
"model_card": "https://www.kaggle.com/models/google/shieldgemma",
},
Expand Down
77 changes: 77 additions & 0 deletions keras_hub/src/models/gemma/shieldgemma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.gemma import gemma_causal_lm
from keras_hub.src.models.task import Task


class ShieldGemmaViolationProbaility(keras.layers.Layer):
"""Relative probabilities for the 'Yes' (violating) and 'No' tokens."""

def __init__(self, yes_token_idx, no_token_idx, **kw):
super().__init__(**kw)
self.yes_token_idx = yes_token_idx
self.no_token_idx = no_token_idx

def call(self, logits, padding_mask):
last_prompt_index = keras.ops.cast(
keras.ops.sum(padding_mask, axis=1) - 1, "int32"
)
last_logits = keras.ops.take(logits, last_prompt_index, axis=1)[:, 0]
yes_logits = last_logits[:, self.yes_token_idx]
no_logits = last_logits[:, self.no_token_idx]
yes_no_logits = keras.ops.stack((yes_logits, no_logits), axis=1)
return keras.ops.softmax(yes_no_logits, axis=1)


@keras_hub_export("keras_hub.models.ShieldGemma")
class ShieldGemma(Task):
"""A ShieldGemma model for safety content moderation, built on Gemma 2.
ShieldGemma is a Gemma 2 variant fine-tuned to detect and predict violations
of four harm types—Harrassment, Hate Speech, Dangerous Content, and
Sexual Content—in text content from a user or model. Architecturally,
the weights are the same as any other Gemma 2 class, but the prediction is
augmented with a final layer that takes returns the probability that the
provided content violates the harm type specified in the prompt.
Links:
* https://arxiv.org/abs/2407.21772
* https://ai.google.dev/gemma/docs/shieldgemma/model_card
* https://ai.google.dev/responsible/docs/safeguards/shieldgemma
* https://www.kaggle.com/models/google/shieldgemma
Args:
gemma: A `keras_hub.models.GemmaCausalLM` initialized with ShieldGemma
weights.
Examples:
Coming soon.
"""

backbone_cls = gemma_causal_lm.GemmaCausalLM.backbone_cls
preprocessor_cls = gemma_causal_lm.GemmaCausalLM.preprocessor_cls

def __init__(self, gemma: gemma_causal_lm.GemmaCausalLM, **kwargs):
# === Layers ===
self.gemma = gemma
self.yes_no_layer = ShieldGemmaViolationProbaility(
yes_token_idx=self.gemma.preprocessor.tokenizer.token_to_id("Yes"),
no_token_idx=self.gemma.preprocessor.tokenizer.token_to_id("No"),
)
self.backbone = self.gemma.backbone
self.preprocessor = self.gemma.preprocessor

# === Functional Model ===
inputs = self.gemma.input
hidden_states = self.gemma(inputs)
outputs = self.yes_no_layer(hidden_states, inputs["padding_mask"])
super().__init__(inputs=inputs, outputs=outputs, **kwargs)

@classmethod
def from_preset(cls, **kwargs):
"""Instantiate a `keras_hub.models.ShieldGemma` from a model preset."""
gemma = gemma_causal_lm.GemmaCausalLM.from_preset(**kwargs)
return cls(gemma)
Empty file.

0 comments on commit 4c4571f

Please sign in to comment.