Skip to content

Commit

Permalink
fix the saving of TypePredictor (#1651)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenmoneygithub authored Oct 20, 2024
1 parent 9a952fe commit ef83201
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
10 changes: 9 additions & 1 deletion dspy/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,20 @@ def __init__(self, signature, instructions=None, *, max_retries=3, wrap_json=Fal
explain_errors: If True, the model will try to explain the errors it encounters.
"""
super().__init__()
self.signature = ensure_signature(signature, instructions)
signature = ensure_signature(signature, instructions)
self.predictor = dspy.Predict(signature, _parse_values=False)
self.max_retries = max_retries
self.wrap_json = wrap_json
self.explain_errors = explain_errors

@property
def signature(self) -> dspy.Signature:
return self.predictor.signature

@signature.setter
def signature(self, value: dspy.Signature):
self.predictor.signature = value

def copy(self) -> "TypedPredictor":
return TypedPredictor(
self.signature,
Expand Down
20 changes: 20 additions & 0 deletions tests/functional/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,3 +822,23 @@ def check_cateogry(self):

pred = predictor(input_data="What is the best animal?", allowed_categories=["cat", "dog"])
assert pred.category == "dog"

def test_save_type_predictor(tmp_path):
class MySignature(dspy.Signature):
"""I am a benigh signature."""
question: str = dspy.InputField()
answer: str = dspy.OutputField()

class CustomModel(dspy.Module):
def __init__(self):
self.predictor = dspy.TypedPredictor(MySignature)

save_path = tmp_path / "state.json"
model = CustomModel()
model.predictor.signature = MySignature.with_instructions("I am a malicious signature.")
model.save(save_path)

loaded = CustomModel()
assert loaded.predictor.signature.instructions == "I am a benigh signature."
loaded.load(save_path)
assert loaded.predictor.signature.instructions == "I am a malicious signature."

0 comments on commit ef83201

Please sign in to comment.