Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add per-module constraint #1844

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
import re
from typing import Any, Union
from dsp.adapters.base_template import Field
from dspy.signatures.signature import Signature
from .base import Adapter
from .image_utils import encode_image, Image

import ast
import json
import enum
import inspect
import pydantic
import json
import re
import textwrap
from collections.abc import Mapping
from itertools import chain
from typing import Any, Dict, KeysView, List, Literal, NamedTuple, Union, get_args, get_origin

import pydantic
from pydantic import TypeAdapter
from collections.abc import Mapping
from pydantic.fields import FieldInfo
from typing import Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin

from dsp.adapters.base_template import Field
from dspy.adapters.base import Adapter
from dspy.signatures.signature import Signature

from ..signatures.field import OutputField
from ..signatures.signature import SignatureMeta
from ..signatures.utils import get_dspy_field_type
from .base import Adapter
from .image_utils import Image, encode_image

field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")

Expand Down Expand Up @@ -116,6 +116,8 @@ def format_fields(self, signature, values, role):


def format_blob(blob):
if not isinstance(blob, str):
blob = str(blob)
if "\n" not in blob and "«" not in blob and "»" not in blob:
return f"«{blob}»"

Expand Down
5 changes: 2 additions & 3 deletions dspy/predict/chain_of_thought.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,14 @@ def __init__(self, signature, rationale_type=None, activated=True, **config):
extended_signature = signature.prepend("reasoning", rationale_type, type_=str)
else:
extended_signature = signature.prepend("rationale", rationale_type, type_=str)

self._predict = dspy.Predict(extended_signature, **config)
self._predict.extended_signature = extended_signature

def forward(self, **kwargs):
assert self.activated in [True, False]

signature = kwargs.pop("new_signature", self._predict.extended_signature if self.activated else self.signature)
return self._predict(signature=signature, **kwargs)
return self._predict(**kwargs)

@property
def demos(self):
Expand Down
83 changes: 72 additions & 11 deletions dspy/predict/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,28 @@
from pydantic import BaseModel

import dsp
from dspy.adapters.image_utils import Image
from dspy.predict.parameter import Parameter
from dspy.primitives.prediction import Prediction
from dspy.primitives.program import Module
from dspy.signatures.signature import ensure_signature, signature_to_template
from dspy.signatures.signature import InputField, ensure_signature, signature_to_template
from dspy.utils.callback import with_callbacks
from dspy.adapters.image_utils import Image

logger = logging.getLogger(__name__)


@lru_cache(maxsize=None)
def warn_once(msg: str):
logging.warning(msg)


class Predict(Module, Parameter):
def __init__(self, signature, _parse_values=True, callbacks=None, **config):
def __init__(self, signature, _parse_values=True, callbacks=None, constraints=None, **config):
self.stage = random.randbytes(8).hex()
self.signature = ensure_signature(signature)
self.config = config
self.callbacks = callbacks or []
self.constraints = constraints or []
self._parse_values = _parse_values
self.reset()

Expand Down Expand Up @@ -74,7 +79,7 @@ def load_state(self, state, use_legacy_loading=False):
if use_legacy_loading:
self._load_state_legacy(state)
return self

if "signature" not in state:
# Check if the state is from a version of DSPy prior to v2.5.3.
raise ValueError(
Expand All @@ -87,7 +92,7 @@ def load_state(self, state, use_legacy_loading=False):
# `excluded_keys` are fields that go through special handling.
if name not in excluded_keys:
setattr(self, name, value)

# FIXME: Images are getting special treatment, but all basemodels initialized from json should be converted back to objects
for demo in self.demos:
for field in demo:
Expand All @@ -96,7 +101,7 @@ def load_state(self, state, use_legacy_loading=False):
if not isinstance(url, str):
raise ValueError(f"Image URL must be a string, got {type(url)}")
demo[field] = Image(url=url)

self.signature = self.signature.load_state(state["signature"])

if "extended_signature" in state:
Expand Down Expand Up @@ -138,20 +143,75 @@ def _load_state_legacy(self, state):

def load(self, path, return_self=False):
"""Load a saved state from a file.

Args:
path (str): Path to the saved state file
return_self (bool): If True, returns self to allow method chaining. Default is False for backwards compatibility.

Returns:
Union[None, Predict]: Returns None if return_self is False (default), returns self if return_self is True
"""
super().load(path)
return self if return_self else None

@with_callbacks
def __call__(self, **kwargs):
return self.forward(**kwargs)
def __call__(self, max_retries=1, soft_fail=False, **kwargs):
from dspy import settings as dspy_settings

if not self.constraints:
return self.forward(**kwargs)

original_signature = self.signature
# Include constraints and failed traces in the signature for retry mode.
retry_signature = (
self.signature.prepend("constraints", InputField(desc="Constraints to satisfy"))
.prepend(
"failed_trace", InputField(desc="Failed traces in previous attempts, empty if it's the first attempt")
)
.prepend(
"violated_constraints_indices",
InputField(desc="Indices (0-indexed) of violated constraints in previous attempts", type=list[int]),
)
)
self.signature = retry_signature
constraints_desc = [constraint.desc for constraint in self.constraints]
violated_constraints_indices = []
failed_trace = None
for retry_idx in range(max_retries):
outputs = self.forward(
constraints=constraints_desc,
failed_trace=failed_trace,
violated_constraints_indices=violated_constraints_indices,
**kwargs,
)
violated_constraints_indices = []
for i, constraint in enumerate(self.constraints):
if not constraint(inputs=kwargs, outputs=outputs):
logger.warning(
f"Constraint {i} is violated at retry {retry_idx}. Constraint description: {constraint.desc}. "
"A retry will be triggered."
)
violated_constraints_indices.append(i)
if len(violated_constraints_indices) > 0:
trace_info = dspy_settings.trace[-1]
failed_trace = {
"inputs": trace_info[1],
"outputs": trace_info[2].toDict(),
}

self.signature = original_signature
for violated_constraint_index in violated_constraints_indices:
if not soft_fail and not self.constraints[violated_constraint_index].soft:
raise ValueError(
f"Constraint {violated_constraints_indices} violated, terminating the program because this is a "
"hard-required constraint and you set `soft_fail=False`."
)
if len(violated_constraints_indices) > 0:
logger.warning(
f"Constraints {violated_constraints_indices} are violated, but since you set `soft_fail=True` or all the "
"violated constraints are soft constraints, the program will continue."
)
return outputs

def forward(self, **kwargs):
assert not dsp.settings.compiling, "It's no longer ever the case that .compiling is True"
Expand Down Expand Up @@ -185,7 +245,7 @@ def forward(self, **kwargs):
import dspy

if isinstance(lm, dspy.LM):
completions = v2_5_generate(lm, config, signature, demos, kwargs, _parse_values=self._parse_values)
completions = v2_5_generate(lm, config, signature, demos, inputs=kwargs, _parse_values=self._parse_values)
else:
warn_once(
"\t*** In DSPy 2.5, all LM clients except `dspy.LM` are deprecated, "
Expand Down Expand Up @@ -296,6 +356,7 @@ def v2_5_generate(lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
lm, lm_kwargs=lm_kwargs, signature=signature, demos=demos, inputs=inputs, _parse_values=_parse_values
)


# TODO: get some defaults during init from the context window?
# # TODO: FIXME: Hmm, I guess expected behavior is that contexts can
# affect execution. Well, we need to determine whether context dominates, __init__ demoninates, or forward dominates.
Expand Down
12 changes: 12 additions & 0 deletions dspy/primitives/constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

from typing import Callable


class Constraint:
def __init__(self, fn: Callable, desc: str, soft: bool = False):
self.fn = fn
self.desc = desc
self.soft = soft

def __call__(self, inputs, outputs):
return self.fn(inputs, outputs)