From 0515f292accd83b236b578202f53e3411d61ded1 Mon Sep 17 00:00:00 2001 From: henrycunh Date: Tue, 15 Aug 2023 08:46:03 -0300 Subject: [PATCH] feat: add cohere streaming and usage support --- .github/workflows/release.yml | 30 ++++++++ cursive/assets/price/cohere.py | 11 +++ cursive/build_input.py | 7 +- cursive/cursive.py | 49 ++++++++---- cursive/pricing.py | 4 +- cursive/usage/anthropic.py | 19 ++--- cursive/usage/cohere.py | 12 +++ cursive/vendor/cohere.py | 137 ++++++++++++++++++++++++++++++++- poetry.lock | 2 +- pyproject.toml | 3 +- 10 files changed, 233 insertions(+), 41 deletions(-) create mode 100644 .github/workflows/release.yml create mode 100644 cursive/assets/price/cohere.py create mode 100644 cursive/usage/cohere.py diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..758d03e --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,30 @@ +name: Release + +permissions: + contents: write + +on: + push: + tags: + - 'v*' + +jobs: + release: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Install pnpm + uses: pnpm/action-setup@v2 + + - name: Set node + uses: actions/setup-node@v3 + with: + node-version: 18.x + cache: pnpm + + - run: npx changelogithub + env: + GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} \ No newline at end of file diff --git a/cursive/assets/price/cohere.py b/cursive/assets/price/cohere.py new file mode 100644 index 0000000..af371af --- /dev/null +++ b/cursive/assets/price/cohere.py @@ -0,0 +1,11 @@ +COHERE_PRICING = { + "version": "2023-07-11", + "command": { + "completion": 0.015, + "prompt": 0.015 + }, + "command-nightly": { + "completion": 0.015, + "prompt": 0.015 + }, +} \ No newline at end of file diff --git a/cursive/build_input.py b/cursive/build_input.py index 09371e0..53fad3c 100644 --- a/cursive/build_input.py +++ b/cursive/build_input.py @@ -55,12 +55,7 @@ def get_function_call_directives(functions: list[CursiveFunction]) -> str: If you need to use a function, always output the result of the function call using the tag using the following format: - {'{'} - "name": "function_name", - "arguments": {'{'} - "argument_name": "argument_value" - {'}'} - {'}'} + {'{'}"name": "function_name", "arguments": {'{'}"argument_name": "argument_value"{'}'}{'}'} Never escape the function call, always output it as it is. ALWAYS use this format, even if the function doesn't have arguments. The arguments prop is always a dictionary. diff --git a/cursive/cursive.py b/cursive/cursive.py index ac34efb..b74298a 100644 --- a/cursive/cursive.py +++ b/cursive/cursive.py @@ -8,7 +8,8 @@ from cursive.build_input import get_function_call_directives from cursive.function import parse_custom_function_call -from cursive.vendor.cohere import CohereClient +from cursive.usage.cohere import get_cohere_usage +from cursive.vendor.cohere import CohereClient, process_cohere_stream from .custom_types import ( BaseModel, @@ -346,12 +347,7 @@ def create_completion( ) - data = { - 'choices': [{ 'message': { 'content': response.completion.lstrip() } }], - 'model': payload.model, - 'id': random_id(), - 'usage': {}, - } + if payload.stream: data = process_anthropic_stream( payload=payload, @@ -359,6 +355,13 @@ def create_completion( response=response, on_token=on_token, ) + else: + data = { + 'choices': [{ 'message': { 'content': response.completion.lstrip() } }], + 'model': payload.model, + 'id': random_id(), + 'usage': {}, + } parse_custom_function_call(data, payload, get_anthropic_usage) @@ -380,19 +383,33 @@ def create_completion( code=CursiveErrorCode.completion_error ) - data = { - 'choices': [{ 'message': { 'content': response.data[0].text.lstrip() } }], - 'model': payload.model, - 'id': random_id(), - 'usage': {}, - } - if payload.stream: # TODO: Implement stream processing for Cohere - pass + data = process_cohere_stream( + payload=payload, + cursive=cursive, + response=response, + on_token=on_token, + ) + else: + data = { + 'choices': [{ 'message': { 'content': response.data[0].text.lstrip() } }], + 'model': payload.model, + 'id': random_id(), + 'usage': {}, + } - parse_custom_function_call(data, payload) + parse_custom_function_call(data, payload, get_cohere_usage) + data['cost'] = resolve_pricing( + vendor='cohere', + usage=CursiveAskUsage( + completion_tokens=data['usage']['completion_tokens'], + prompt_tokens=data['usage']['prompt_tokens'], + total_tokens=data['usage']['total_tokens'], + ), + model=data['model'] + ) end = time.time() if data.get('error'): diff --git a/cursive/pricing.py b/cursive/pricing.py index 0e9c6f0..6505132 100644 --- a/cursive/pricing.py +++ b/cursive/pricing.py @@ -5,10 +5,12 @@ from .assets.price.anthropic import ANTHROPIC_PRICING from .assets.price.openai import OPENAI_PRICING +from .assets.price.cohere import COHERE_PRICING VENDOR_PRICING = { 'openai': OPENAI_PRICING, - 'anthropic': ANTHROPIC_PRICING + 'anthropic': ANTHROPIC_PRICING, + 'cohere': COHERE_PRICING } def resolve_pricing( diff --git a/cursive/usage/anthropic.py b/cursive/usage/anthropic.py index 08cbdac..774052e 100644 --- a/cursive/usage/anthropic.py +++ b/cursive/usage/anthropic.py @@ -1,23 +1,14 @@ from anthropic import Anthropic +from cursive.build_input import build_completion_input + from ..custom_types import CompletionMessage def get_anthropic_usage(content: str | list[CompletionMessage]): client = Anthropic() - if type(content) == str: - return client.count_tokens(content) - - def function(message: CompletionMessage): - if message.role == 'system': - return f''' - Human: {message.content} - - Assistant: Ok. - ''' - return f'{message.role}: {message.content}' - - mapped_content = '\n\n'.join(list(map(function, content))) # type: ignore + if type(content) != str: + content = build_completion_input(content) - return client.count_tokens(mapped_content) + return client.count_tokens(content) diff --git a/cursive/usage/cohere.py b/cursive/usage/cohere.py new file mode 100644 index 0000000..5bd1f86 --- /dev/null +++ b/cursive/usage/cohere.py @@ -0,0 +1,12 @@ +from tokenizers import Tokenizer +from cursive.build_input import build_completion_input +from cursive.custom_types import CompletionMessage + + +def get_cohere_usage(content: str | list[CompletionMessage]): + tokenizer = Tokenizer.from_pretrained("Cohere/command-nightly") + + if type(content) != str: + content = build_completion_input(content) + + return len(tokenizer.encode(content).ids) \ No newline at end of file diff --git a/cursive/vendor/cohere.py b/cursive/vendor/cohere.py index ee474e8..a9f1b48 100644 --- a/cursive/vendor/cohere.py +++ b/cursive/vendor/cohere.py @@ -1,8 +1,10 @@ +import re +from typing import Any, Optional from cohere import Client as Cohere from cursive.build_input import build_completion_input -from cursive.custom_types import CompletionPayload -from cursive.utils import filter_null_values +from cursive.custom_types import CompletionPayload, CursiveAskOnToken +from cursive.utils import filter_null_values, random_id class CohereClient: client: Cohere @@ -34,3 +36,134 @@ def create_completion(self, payload: CompletionPayload): +def process_cohere_stream( + payload: CompletionPayload, + cursive: Any, + response: Any, + on_token: Optional[CursiveAskOnToken] = None, +): + data = { + 'choices': [{ 'message': { 'content': '' } }], + 'usage': { + 'completion_tokens': 0, + 'prompt_tokens': 0, + }, + 'model': payload.model, + } + + completion = '' + for slice in response: + data = { + **data, + 'id': random_id(), + } + + # The completion partial will come with a leading whitespace + completion += slice.text + if not data['choices'][0]['message']['content']: + completion = completion.lstrip() + + # Check if theres any tag. The regex should allow for nested tags + function_call_tag = re.findall( + r'([\s\S]*?)(?=<\/function-call>|$)', + completion + ) + function_name = '' + function_arguments = '' + if len(function_call_tag) > 0: + # Remove starting and ending tags, even if the ending tag is partial or missing + function_call = re.sub( + r'^\n|\n$', + '', + re.sub( + r'<\/?f?u?n?c?t?i?o?n?-?c?a?l?l?>?', + '', + function_call_tag[0] + ).strip() + ).strip() + # Match the function name inside the JSON + function_name_matches = re.findall( + r'"name":\s*"(.+)"', + function_call + ) + function_name = len(function_name_matches) > 0 and function_name_matches[0] + function_arguments_matches = re.findall( + r'"arguments":\s*(\{.+)\}?', + function_call, + re.S + ) + function_arguments = ( + len(function_arguments_matches) > 0 and + function_arguments_matches[0] + ) + if function_arguments: + # If theres unmatches } at the end, remove them + unmatched_brackets = re.findall( + r'(\{|\})', + function_arguments + ) + if len(unmatched_brackets) % 2: + function_arguments = re.sub( + r'\}$', + '', + function_arguments.strip() + ) + + function_arguments = function_arguments.strip() + + cursive_answer_tag = re.findall( + r'([\s\S]*?)(?=<\/cursive-answer>|$)', + completion + ) + tagged_answer = '' + if cursive_answer_tag: + tagged_answer = re.sub( + r'<\/?c?u?r?s?i?v?e?-?a?n?s?w?e?r?>?', + '', + cursive_answer_tag[0] + ).lstrip() + + current_token = completion[ + len(data['choices'][0]['message']['content']): + ] + + data['choices'][0]['message']['content'] += current_token + + if on_token: + chunk = None + + if payload.functions: + if function_name: + chunk = { + 'function_call': {}, + 'content': None, + } + if function_arguments: + # Remove all but the current token from the arguments + chunk['function_call']['arguments'] = function_arguments + else: + chunk['function_call'] = { + 'name': function_name, + 'arguments': '', + } + elif tagged_answer: + # Token is at the end of the tagged answer + regex = fr'(.*){current_token.strip()}$' + match = re.findall( + regex, + tagged_answer + ) + if len(match) > 0 and current_token: + chunk = { + 'function_call': None, + 'content': current_token, + } + else: + chunk = { + 'content': current_token, + } + + if chunk: + on_token(chunk) + + return data diff --git a/poetry.lock b/poetry.lock index 078d432..03b8917 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1155,4 +1155,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10.9,<3.12" -content-hash = "11991b232615f9ade19fdad4faf80734ae7eadceeef326e48ce7c51f122da126" +content-hash = "56b699ca6fc6bdbd0d78fd1988e81a1b192140649acca4a2368ed82aedf28822" diff --git a/pyproject.toml b/pyproject.toml index 38208b0..00261ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cursivepy" -version = "0.1.6" +version = "0.3.0" description = "" authors = ["Rodrigo Godinho ", "Henrique Cunha "] readme = "README.md" @@ -13,6 +13,7 @@ openai = "^0.27.8" anthropic = "^0.3.6" pydantic = "^1.9.0" cohere = "^4.19.3" +tokenizers = "^0.13.3" [tool.poetry.group.dev.dependencies] pytest = "^7.4.0"