Skip to content

Commit

Permalink
move aetherroom import to separate module
Browse files Browse the repository at this point in the history
  • Loading branch information
whjms committed Mar 9, 2023
1 parent d7854e9 commit ad2c2b6
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 31 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,17 @@ Uninstall
flask_session
accelerate-disk-cache
.ipynb_checkpoints
unit_test_report.html

# Temporary until HF port
!models/RWKV-v4
models/RWKV-v4/20B_tokenizer.json
models/RWKV-v4/src/__pycache__
models/RWKV-v4/models

# Ignore PyCharm project files.
# Ignore PyCharm, VSCode project files.
.idea
.vscode

# Ignore compiled Python files.
*.pyc
Expand Down
45 changes: 15 additions & 30 deletions aiserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,40 +414,25 @@ def replace_placeholders(self, ph_ids: dict):
self.world_infos[i][key] = self._replace_placeholders(self.world_infos[i][key])

def from_club(self, club_id):
# Maybe it is a better to parse the NAI Scenario (if available), it has more data
r = requests.get(f"https://aetherroom.club/api/{club_id}")

if not r.ok:
print(f"[import] Got {r.status_code} on request to club :^(")
message = f"Club responded with {r.status_code}"
if r.status_code == "404":
from importers import aetherroom
import_data: aetherroom.ImportData
try:
import_data = aetherroom.import_scenario(club_id)
except aetherroom.RequestFailed as err:
status = err.status_code
print(f"[import] Got {status} on request to club :^(")
message = f"Club responded with {status}"
if status == "404":
message = f"Prompt not found for ID {club_id}"
show_error_notification("Error loading prompt", message)
return

j = r.json()

self.prompt = j["promptContent"]
self.memory = j["memory"]
self.authors_note = j["authorsNote"]
self.notes = j["description"]
self.title = j["title"] or "Imported Story"

self.world_infos = []

for wi in j["worldInfos"]:
self.world_infos.append({
"key_list": wi["keysList"],
"keysecondary": [],
"content": wi["entry"],
"comment": "",
"folder": wi.get("folder", None),
"num": 0,
"init": True,
"selective": wi.get("selective", False),
"constant": wi.get("constant", False),
"uid": None,
})
self.prompt = import_data.prompt
self.memory = import_data.memory
self.authors_note = import_data.authors_note
self.notes = import_data.notes
self.title = import_data.title
self.world_infos = import_data.world_infos

placeholders = self.extract_placeholders(self.prompt)
if not placeholders:
Expand Down
57 changes: 57 additions & 0 deletions importers/aetherroom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from dataclasses import dataclass
import requests
from typing import List

BASE_URL = "https://aetherroom.club/api/"


@dataclass
class ImportData:
prompt: str
memory: str
authors_note: str
notes: str
title: str
world_infos: List[object]


class RequestFailed(Exception):
def __init__(self, status_code: str) -> None:
self.status_code = status_code
super().__init__()


def import_scenario(id: int) -> ImportData:
"""
Fetches story info from the provided AetherRoom scenario ID.
"""
# Maybe it is a better to parse the NAI Scenario (if available), it has more data
req = requests.get(f"{BASE_URL}{id}")
if not req.ok:
raise RequestFailed(req.status_code)

json = req.json()
prompt = json["promptContent"]
memory = json["memory"]
authors_note = json["authorsNote"]
notes = json["description"]
title = json.get("title", "Imported Story")

world_infos = []
for info in json["worldinfos"]:
world_infos.append(
{
"key_list": info["keysList"],
"keysecondary": [],
"content": info["entry"],
"comment": "",
"folder": info.get("folder", None),
"num": 0,
"init": True,
"selective": info.get("selective", False),
"constant": info.get("constant", False),
"uid": None,
}
)

return ImportData(prompt, memory, authors_note, notes, title, world_infos)
149 changes: 149 additions & 0 deletions importers/test_aetherroom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import pytest
import requests_mock

from importers.aetherroom import (
ImportData,
RequestFailed,
import_scenario,
)


def test_import_scenario_http_error(requests_mock: requests_mock.mocker):
requests_mock.get("https://aetherroom.club/api/1", status_code=404)
with pytest.raises(RequestFailed):
import_scenario(1)


def test_import_scenario_success(requests_mock: requests_mock.Mocker):
json = {
"promptContent": "promptContent",
"memory": "memory",
"authorsNote": "authorsNote",
"description": "description",
"title": "title",
"worldinfos": [],
}
requests_mock.get("https://aetherroom.club/api/1", json=json)

expected_import_data = ImportData(
"promptContent", "memory", "authorsNote", "description", "title", []
)
assert import_scenario(1) == expected_import_data


def test_import_scenario_no_title(requests_mock: requests_mock.Mocker):
json = {
"promptContent": "promptContent",
"memory": "memory",
"authorsNote": "authorsNote",
"description": "description",
"worldinfos": [],
}
requests_mock.get("https://aetherroom.club/api/1", json=json)

expected_import_data = ImportData(
"promptContent", "memory", "authorsNote", "description", "Imported Story", []
)
assert import_scenario(1) == expected_import_data


def test_import_scenario_world_infos(requests_mock: requests_mock.Mocker):
json = {
"promptContent": "promptContent",
"memory": "memory",
"authorsNote": "authorsNote",
"description": "description",
"worldinfos": [
{
"entry": "Info 1",
"keysList": ["a", "b", "c"],
"folder": "folder",
"selective": True,
"constant": True,
},
{
"entry": "Info 2",
"keysList": ["d", "e", "f"],
"folder": "folder 2",
"selective": True,
"constant": True,
},
],
}
requests_mock.get("https://aetherroom.club/api/1", json=json)

expected_import_data = ImportData(
"promptContent",
"memory",
"authorsNote",
"description",
"Imported Story",
[
{
"content": "Info 1",
"key_list": ["a", "b", "c"],
"keysecondary": [],
"comment": "",
"num": 0,
"init": True,
"uid": None,
"folder": "folder",
"selective": True,
"constant": True,
},
{
"content": "Info 2",
"key_list": ["d", "e", "f"],
"keysecondary": [],
"comment": "",
"num": 0,
"init": True,
"uid": None,
"folder": "folder 2",
"selective": True,
"constant": True,
},
],
)
assert import_scenario(1) == expected_import_data


def test_import_scenario_world_info_missing_properties(
requests_mock: requests_mock.Mocker,
):
json = {
"promptContent": "promptContent",
"memory": "memory",
"authorsNote": "authorsNote",
"description": "description",
"worldinfos": [
{
"entry": "Info 1",
"keysList": ["a", "b", "c"],
}
],
}
requests_mock.get("https://aetherroom.club/api/1", json=json)

expected_import_data = ImportData(
"promptContent",
"memory",
"authorsNote",
"description",
"Imported Story",
[
{
"content": "Info 1",
"key_list": ["a", "b", "c"],
"keysecondary": [],
"comment": "",
"num": 0,
"init": True,
"uid": None,
"folder": None,
"selective": False,
"constant": False,
}
],
)
assert import_scenario(1) == expected_import_data
6 changes: 6 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dnspython==2.2.1
lupa==1.10
markdown
bleach==4.1.0
black
sentencepiece
protobuf
accelerate
Expand All @@ -29,5 +30,10 @@ flask_compress
ijson
bitsandbytes
ftfy
py==1.11.0
pydub
pytest==7.2.2
pytest-html==3.2.0
pytest-metadata==2.0.4
requests-mock==1.10.0
safetensors

0 comments on commit ad2c2b6

Please sign in to comment.