forked from henk717/KoboldAI
-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
move aetherroom import to separate module
- Loading branch information
whjms
committed
Mar 9, 2023
1 parent
d7854e9
commit ad2c2b6
Showing
5 changed files
with
230 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from dataclasses import dataclass | ||
import requests | ||
from typing import List | ||
|
||
BASE_URL = "https://aetherroom.club/api/" | ||
|
||
|
||
@dataclass | ||
class ImportData: | ||
prompt: str | ||
memory: str | ||
authors_note: str | ||
notes: str | ||
title: str | ||
world_infos: List[object] | ||
|
||
|
||
class RequestFailed(Exception): | ||
def __init__(self, status_code: str) -> None: | ||
self.status_code = status_code | ||
super().__init__() | ||
|
||
|
||
def import_scenario(id: int) -> ImportData: | ||
""" | ||
Fetches story info from the provided AetherRoom scenario ID. | ||
""" | ||
# Maybe it is a better to parse the NAI Scenario (if available), it has more data | ||
req = requests.get(f"{BASE_URL}{id}") | ||
if not req.ok: | ||
raise RequestFailed(req.status_code) | ||
|
||
json = req.json() | ||
prompt = json["promptContent"] | ||
memory = json["memory"] | ||
authors_note = json["authorsNote"] | ||
notes = json["description"] | ||
title = json.get("title", "Imported Story") | ||
|
||
world_infos = [] | ||
for info in json["worldinfos"]: | ||
world_infos.append( | ||
{ | ||
"key_list": info["keysList"], | ||
"keysecondary": [], | ||
"content": info["entry"], | ||
"comment": "", | ||
"folder": info.get("folder", None), | ||
"num": 0, | ||
"init": True, | ||
"selective": info.get("selective", False), | ||
"constant": info.get("constant", False), | ||
"uid": None, | ||
} | ||
) | ||
|
||
return ImportData(prompt, memory, authors_note, notes, title, world_infos) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
import pytest | ||
import requests_mock | ||
|
||
from importers.aetherroom import ( | ||
ImportData, | ||
RequestFailed, | ||
import_scenario, | ||
) | ||
|
||
|
||
def test_import_scenario_http_error(requests_mock: requests_mock.mocker): | ||
requests_mock.get("https://aetherroom.club/api/1", status_code=404) | ||
with pytest.raises(RequestFailed): | ||
import_scenario(1) | ||
|
||
|
||
def test_import_scenario_success(requests_mock: requests_mock.Mocker): | ||
json = { | ||
"promptContent": "promptContent", | ||
"memory": "memory", | ||
"authorsNote": "authorsNote", | ||
"description": "description", | ||
"title": "title", | ||
"worldinfos": [], | ||
} | ||
requests_mock.get("https://aetherroom.club/api/1", json=json) | ||
|
||
expected_import_data = ImportData( | ||
"promptContent", "memory", "authorsNote", "description", "title", [] | ||
) | ||
assert import_scenario(1) == expected_import_data | ||
|
||
|
||
def test_import_scenario_no_title(requests_mock: requests_mock.Mocker): | ||
json = { | ||
"promptContent": "promptContent", | ||
"memory": "memory", | ||
"authorsNote": "authorsNote", | ||
"description": "description", | ||
"worldinfos": [], | ||
} | ||
requests_mock.get("https://aetherroom.club/api/1", json=json) | ||
|
||
expected_import_data = ImportData( | ||
"promptContent", "memory", "authorsNote", "description", "Imported Story", [] | ||
) | ||
assert import_scenario(1) == expected_import_data | ||
|
||
|
||
def test_import_scenario_world_infos(requests_mock: requests_mock.Mocker): | ||
json = { | ||
"promptContent": "promptContent", | ||
"memory": "memory", | ||
"authorsNote": "authorsNote", | ||
"description": "description", | ||
"worldinfos": [ | ||
{ | ||
"entry": "Info 1", | ||
"keysList": ["a", "b", "c"], | ||
"folder": "folder", | ||
"selective": True, | ||
"constant": True, | ||
}, | ||
{ | ||
"entry": "Info 2", | ||
"keysList": ["d", "e", "f"], | ||
"folder": "folder 2", | ||
"selective": True, | ||
"constant": True, | ||
}, | ||
], | ||
} | ||
requests_mock.get("https://aetherroom.club/api/1", json=json) | ||
|
||
expected_import_data = ImportData( | ||
"promptContent", | ||
"memory", | ||
"authorsNote", | ||
"description", | ||
"Imported Story", | ||
[ | ||
{ | ||
"content": "Info 1", | ||
"key_list": ["a", "b", "c"], | ||
"keysecondary": [], | ||
"comment": "", | ||
"num": 0, | ||
"init": True, | ||
"uid": None, | ||
"folder": "folder", | ||
"selective": True, | ||
"constant": True, | ||
}, | ||
{ | ||
"content": "Info 2", | ||
"key_list": ["d", "e", "f"], | ||
"keysecondary": [], | ||
"comment": "", | ||
"num": 0, | ||
"init": True, | ||
"uid": None, | ||
"folder": "folder 2", | ||
"selective": True, | ||
"constant": True, | ||
}, | ||
], | ||
) | ||
assert import_scenario(1) == expected_import_data | ||
|
||
|
||
def test_import_scenario_world_info_missing_properties( | ||
requests_mock: requests_mock.Mocker, | ||
): | ||
json = { | ||
"promptContent": "promptContent", | ||
"memory": "memory", | ||
"authorsNote": "authorsNote", | ||
"description": "description", | ||
"worldinfos": [ | ||
{ | ||
"entry": "Info 1", | ||
"keysList": ["a", "b", "c"], | ||
} | ||
], | ||
} | ||
requests_mock.get("https://aetherroom.club/api/1", json=json) | ||
|
||
expected_import_data = ImportData( | ||
"promptContent", | ||
"memory", | ||
"authorsNote", | ||
"description", | ||
"Imported Story", | ||
[ | ||
{ | ||
"content": "Info 1", | ||
"key_list": ["a", "b", "c"], | ||
"keysecondary": [], | ||
"comment": "", | ||
"num": 0, | ||
"init": True, | ||
"uid": None, | ||
"folder": None, | ||
"selective": False, | ||
"constant": False, | ||
} | ||
], | ||
) | ||
assert import_scenario(1) == expected_import_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters