Skip to content

Commit

Permalink
feat: add cohere streaming and usage support
Browse files Browse the repository at this point in the history
  • Loading branch information
henrycunh committed Aug 15, 2023
1 parent 8367ee7 commit 0515f29
Show file tree
Hide file tree
Showing 10 changed files with 233 additions and 41 deletions.
30 changes: 30 additions & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
name: Release

permissions:
contents: write

on:
push:
tags:
- 'v*'

jobs:
release:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0

- name: Install pnpm
uses: pnpm/action-setup@v2

- name: Set node
uses: actions/setup-node@v3
with:
node-version: 18.x
cache: pnpm

- run: npx changelogithub
env:
GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
11 changes: 11 additions & 0 deletions cursive/assets/price/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
COHERE_PRICING = {
"version": "2023-07-11",
"command": {
"completion": 0.015,
"prompt": 0.015
},
"command-nightly": {
"completion": 0.015,
"prompt": 0.015
},
}
7 changes: 1 addition & 6 deletions cursive/build_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,7 @@ def get_function_call_directives(functions: list[CursiveFunction]) -> str:
If you need to use a function, always output the result of the function call using the <function-call> tag using the following format:
<function-call>
{'{'}
"name": "function_name",
"arguments": {'{'}
"argument_name": "argument_value"
{'}'}
{'}'}
{'{'}"name": "function_name", "arguments": {'{'}"argument_name": "argument_value"{'}'}{'}'}
</function-call>
Never escape the function call, always output it as it is.
ALWAYS use this format, even if the function doesn't have arguments. The arguments prop is always a dictionary.
Expand Down
49 changes: 33 additions & 16 deletions cursive/cursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from cursive.build_input import get_function_call_directives
from cursive.function import parse_custom_function_call
from cursive.vendor.cohere import CohereClient
from cursive.usage.cohere import get_cohere_usage
from cursive.vendor.cohere import CohereClient, process_cohere_stream

from .custom_types import (
BaseModel,
Expand Down Expand Up @@ -346,19 +347,21 @@ def create_completion(
)


data = {
'choices': [{ 'message': { 'content': response.completion.lstrip() } }],
'model': payload.model,
'id': random_id(),
'usage': {},
}

if payload.stream:
data = process_anthropic_stream(
payload=payload,
cursive=cursive,
response=response,
on_token=on_token,
)
else:
data = {
'choices': [{ 'message': { 'content': response.completion.lstrip() } }],
'model': payload.model,
'id': random_id(),
'usage': {},
}

parse_custom_function_call(data, payload, get_anthropic_usage)

Expand All @@ -380,19 +383,33 @@ def create_completion(
code=CursiveErrorCode.completion_error
)

data = {
'choices': [{ 'message': { 'content': response.data[0].text.lstrip() } }],
'model': payload.model,
'id': random_id(),
'usage': {},
}

if payload.stream:
# TODO: Implement stream processing for Cohere
pass
data = process_cohere_stream(
payload=payload,
cursive=cursive,
response=response,
on_token=on_token,
)
else:
data = {
'choices': [{ 'message': { 'content': response.data[0].text.lstrip() } }],
'model': payload.model,
'id': random_id(),
'usage': {},
}

parse_custom_function_call(data, payload)
parse_custom_function_call(data, payload, get_cohere_usage)

data['cost'] = resolve_pricing(
vendor='cohere',
usage=CursiveAskUsage(
completion_tokens=data['usage']['completion_tokens'],
prompt_tokens=data['usage']['prompt_tokens'],
total_tokens=data['usage']['total_tokens'],
),
model=data['model']
)
end = time.time()

if data.get('error'):
Expand Down
4 changes: 3 additions & 1 deletion cursive/pricing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

from .assets.price.anthropic import ANTHROPIC_PRICING
from .assets.price.openai import OPENAI_PRICING
from .assets.price.cohere import COHERE_PRICING

VENDOR_PRICING = {
'openai': OPENAI_PRICING,
'anthropic': ANTHROPIC_PRICING
'anthropic': ANTHROPIC_PRICING,
'cohere': COHERE_PRICING
}

def resolve_pricing(
Expand Down
19 changes: 5 additions & 14 deletions cursive/usage/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
from anthropic import Anthropic

from cursive.build_input import build_completion_input

from ..custom_types import CompletionMessage

def get_anthropic_usage(content: str | list[CompletionMessage]):
client = Anthropic()

if type(content) == str:
return client.count_tokens(content)

def function(message: CompletionMessage):
if message.role == 'system':
return f'''
Human: {message.content}
Assistant: Ok.
'''
return f'{message.role}: {message.content}'

mapped_content = '\n\n'.join(list(map(function, content))) # type: ignore
if type(content) != str:
content = build_completion_input(content)

return client.count_tokens(mapped_content)
return client.count_tokens(content)

12 changes: 12 additions & 0 deletions cursive/usage/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from tokenizers import Tokenizer
from cursive.build_input import build_completion_input
from cursive.custom_types import CompletionMessage


def get_cohere_usage(content: str | list[CompletionMessage]):
tokenizer = Tokenizer.from_pretrained("Cohere/command-nightly")

if type(content) != str:
content = build_completion_input(content)

return len(tokenizer.encode(content).ids)
137 changes: 135 additions & 2 deletions cursive/vendor/cohere.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import re
from typing import Any, Optional
from cohere import Client as Cohere
from cursive.build_input import build_completion_input

from cursive.custom_types import CompletionPayload
from cursive.utils import filter_null_values
from cursive.custom_types import CompletionPayload, CursiveAskOnToken
from cursive.utils import filter_null_values, random_id

class CohereClient:
client: Cohere
Expand Down Expand Up @@ -34,3 +36,134 @@ def create_completion(self, payload: CompletionPayload):



def process_cohere_stream(
payload: CompletionPayload,
cursive: Any,
response: Any,
on_token: Optional[CursiveAskOnToken] = None,
):
data = {
'choices': [{ 'message': { 'content': '' } }],
'usage': {
'completion_tokens': 0,
'prompt_tokens': 0,
},
'model': payload.model,
}

completion = ''
for slice in response:
data = {
**data,
'id': random_id(),
}

# The completion partial will come with a leading whitespace
completion += slice.text
if not data['choices'][0]['message']['content']:
completion = completion.lstrip()

# Check if theres any <function-call> tag. The regex should allow for nested tags
function_call_tag = re.findall(
r'<function-call>([\s\S]*?)(?=<\/function-call>|$)',
completion
)
function_name = ''
function_arguments = ''
if len(function_call_tag) > 0:
# Remove <function-call> starting and ending tags, even if the ending tag is partial or missing
function_call = re.sub(
r'^\n|\n$',
'',
re.sub(
r'<\/?f?u?n?c?t?i?o?n?-?c?a?l?l?>?',
'',
function_call_tag[0]
).strip()
).strip()
# Match the function name inside the JSON
function_name_matches = re.findall(
r'"name":\s*"(.+)"',
function_call
)
function_name = len(function_name_matches) > 0 and function_name_matches[0]
function_arguments_matches = re.findall(
r'"arguments":\s*(\{.+)\}?',
function_call,
re.S
)
function_arguments = (
len(function_arguments_matches) > 0 and
function_arguments_matches[0]
)
if function_arguments:
# If theres unmatches } at the end, remove them
unmatched_brackets = re.findall(
r'(\{|\})',
function_arguments
)
if len(unmatched_brackets) % 2:
function_arguments = re.sub(
r'\}$',
'',
function_arguments.strip()
)

function_arguments = function_arguments.strip()

cursive_answer_tag = re.findall(
r'<cursive-answer>([\s\S]*?)(?=<\/cursive-answer>|$)',
completion
)
tagged_answer = ''
if cursive_answer_tag:
tagged_answer = re.sub(
r'<\/?c?u?r?s?i?v?e?-?a?n?s?w?e?r?>?',
'',
cursive_answer_tag[0]
).lstrip()

current_token = completion[
len(data['choices'][0]['message']['content']):
]

data['choices'][0]['message']['content'] += current_token

if on_token:
chunk = None

if payload.functions:
if function_name:
chunk = {
'function_call': {},
'content': None,
}
if function_arguments:
# Remove all but the current token from the arguments
chunk['function_call']['arguments'] = function_arguments
else:
chunk['function_call'] = {
'name': function_name,
'arguments': '',
}
elif tagged_answer:
# Token is at the end of the tagged answer
regex = fr'(.*){current_token.strip()}$'
match = re.findall(
regex,
tagged_answer
)
if len(match) > 0 and current_token:
chunk = {
'function_call': None,
'content': current_token,
}
else:
chunk = {
'content': current_token,
}

if chunk:
on_token(chunk)

return data
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cursivepy"
version = "0.1.6"
version = "0.3.0"
description = ""
authors = ["Rodrigo Godinho <[email protected]>", "Henrique Cunha <[email protected]>"]
readme = "README.md"
Expand All @@ -13,6 +13,7 @@ openai = "^0.27.8"
anthropic = "^0.3.6"
pydantic = "^1.9.0"
cohere = "^4.19.3"
tokenizers = "^0.13.3"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
Expand Down

0 comments on commit 0515f29

Please sign in to comment.