AI chat: Added the debug llm strategy which gives more information about the llm generation process
mostlikely4r committed Nov 22, 2024
1 parent 920b265 commit 7f9ed83
Showing 5 changed files with 133 additions and 18 deletions.
80 changes: 69 additions & 11 deletions playerbot/PlayerbotLLMInterface.cpp
@@ -109,18 +109,24 @@ std::string PlayerbotLLMInterface::Generate(const std::string& prompt, std::vect
sPlayerbotLLMInterface.generationCount++;

#ifdef _WIN32
// Initialize Winsock
if (debug)
debugLines.push_back("Initialize Winsock");

WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
if (debug)
debugLines.push_back("WSAStartup failed");

sLog.outError("BotLLM: WSAStartup failed");
return "error";
}
#endif

// Parse the URL
ParsedUrl parsedUrl = sPlayerbotAIConfig.llmEndPointUrl;

// Resolve hostname to IP address
if (debug)
debugLines.push_back("Resolve hostname to IP address: " + parsedUrl.hostname + " " + std::to_string(parsedUrl.port));

struct addrinfo hints = {}, * res;
hints.ai_family = AF_INET;
hints.ai_socktype = SOCK_STREAM;
@@ -132,26 +138,37 @@ std::string PlayerbotLLMInterface::Generate(const std::string& prompt, std::vect
return "error";
}

// Create a socket
if (debug)
debugLines.push_back("Create a socket");
int sock;
#ifdef _WIN32
sock = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
if (sock == INVALID_SOCKET) {
if (debug)
debugLines.push_back("Socket creation failed");

sLog.outError("BotLLM: Socket creation failed");
WSACleanup();
return "error";
}
#else
sock = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
if (sock < 0) {
if (debug)
debugLines.push_back("Socket creation failed");
sLog.outError("BotLLM: Socket creation failed");
freeaddrinfo(res);
return "error";
}
#endif

// Connect to the server
if (debug)
debugLines.push_back("Connect to the server");

if (connect(sock, res->ai_addr, res->ai_addrlen) < 0) {
if (debug)
debugLines.push_back("Connection to server failed");

sLog.outError("BotLLM: Connection to server failed");
#ifdef _WIN32
closesocket(sock);
@@ -163,9 +180,8 @@ std::string PlayerbotLLMInterface::Generate(const std::string& prompt, std::vect
return "error";
}

freeaddrinfo(res); // Free the address info structure
freeaddrinfo(res);

// Create the HTTP POST request
std::ostringstream request;
request << "POST " << parsedUrl.path << " HTTP/1.1\r\n";
request << "Host: " << parsedUrl.hostname << "\r\n";
@@ -177,8 +193,13 @@ std::string PlayerbotLLMInterface::Generate(const std::string& prompt, std::vect
request << "\r\n";
request << body;

// Send the request
if (debug)
debugLines.push_back("Send the request" + request.str());

if (send(sock, request.str().c_str(), request.str().size(), 0) < 0) {
if (debug)
debugLines.push_back("Failed to send request");

sLog.outError("BotLLM: Failed to send request");
#ifdef _WIN32
closesocket(sock);
@@ -189,30 +210,45 @@ std::string PlayerbotLLMInterface::Generate(const std::string& prompt, std::vect
return "error";
}

// Read the response
if (debug)
debugLines.push_back("Read the response");

int bytesRead;

std::string response = RecvWithTimeout(sock, sPlayerbotAIConfig.llmGenerationTimeout, bytesRead);

#ifdef _WIN32
if (bytesRead == SOCKET_ERROR) {
if (debug)
debugLines.push_back("Error reading response");
sLog.outError("BotLLM: Error reading response");
}
closesocket(sock);
WSACleanup();
#else
if (bytesRead < 0) {
if (debug)
debugLines.push_back("Error reading response");
sLog.outError("BotLLM: Error reading response");
}
close(sock);
#endif

// Extract the response body (optional: depending on the server response format)
sPlayerbotLLMInterface.generationCount--;

if (debug)
{
if (!response.empty())
debugLines.push_back(response);
else
debugLines.push_back("Empty response");
}

size_t pos = response.find("\r\n\r\n");
if (pos != std::string::npos) {
response = response.substr(pos + 4);
if (debug)
debugLines.push_back(response);
}

return response;
@@ -279,18 +315,40 @@ inline std::vector<std::string> splitResponse(const std::string& response, const
return result;
}

std::vector<std::string> PlayerbotLLMInterface::ParseResponse(const std::string& response, const std::string& startPattern, const std::string& endPattern, const std::string& splitPattern)
std::vector<std::string> PlayerbotLLMInterface::ParseResponse(const std::string& response, const std::string& startPattern, const std::string& endPattern, const std::string& splitPattern, std::vector<std::string>& debugLines)
{
bool debug = !(debugLines.empty());
uint32 startCursor = 0;
uint32 endCursor = 0;

std::string actualResponse = response;

if (debug)
debugLines.push_back("start pattern:" + startPattern);

actualResponse = extractAfterPattern(actualResponse, startPattern);

if (debug)
{
if (!actualResponse.empty())
debugLines.push_back(actualResponse);
else
debugLines.push_back("Empty response");
}

if (debug)
debugLines.push_back("end pattern:" + endPattern);

actualResponse = extractBeforePattern(actualResponse, endPattern);

if (debug)
debugLines.push_back(actualResponse);

if (debug)
debugLines.push_back("split pattern:" + splitPattern);

std::vector<std::string> responses = splitResponse(actualResponse, splitPattern);

if (debug)
debugLines.insert(debugLines.end(), responses.begin(), responses.end());

return responses;
}

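Generate() above delegates the socket read to RecvWithTimeout, which is outside this diff. For orientation, a minimal sketch of what such a helper can look like — assuming a select()-based per-read timeout and the signature implied by the call site; the repository's actual implementation may differ:

    // Sketch only: reads until the peer closes, the timeout elapses, or an
    // error occurs. bytesRead reports the last recv() result, which is what
    // the SOCKET_ERROR / bytesRead < 0 checks at the call site expect.
    std::string RecvWithTimeout(int sock, int timeoutSeconds, int& bytesRead)
    {
        std::string response;
        char buffer[4096];
        bytesRead = 0;

        while (true)
        {
            fd_set readSet;
            FD_ZERO(&readSet);
            FD_SET(sock, &readSet);

            timeval tv = {};
            tv.tv_sec = timeoutSeconds;

            // Block until data arrives or the timeout expires.
            int ready = select(sock + 1, &readSet, nullptr, nullptr, &tv);
            if (ready <= 0)
                break; // timeout (0) or select error (<0): return what we have

            bytesRead = recv(sock, buffer, sizeof(buffer), 0);
            if (bytesRead <= 0)
                break; // connection closed (0) or recv error (<0)

            response.append(buffer, bytesRead);
        }

        return response;
    }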
1 change: 0 additions & 1 deletion playerbot/PlayerbotLLMInterface.h
@@ -2,7 +2,6 @@ class PlayerbotLLMInterface
{
public:
PlayerbotLLMInterface() {}
static std::string Generate(const std::string& prompt);

static std::string Generate(const std::string& prompt, std::vector<std::string>& debugLines);

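The header change removes the old single-argument Generate overload: every caller now passes a debugLines vector, and, as the bool debug = !(debugLines.empty()) check in ParseResponse shows, handing in a non-empty vector is what switches the tracing on. SayAction.cpp below also runs every prompt fragment through PlayerbotLLMInterface::SanitizeForJson before splicing it into the hand-built request body. That function is not part of this diff either; a sketch of the escaping it plausibly performs (name taken from the call site, details assumed):

    // Sketch: escape the characters that would otherwise break the
    // hand-assembled JSON request body sent to the LLM endpoint.
    std::string SanitizeForJson(const std::string& value)
    {
        std::string out;
        out.reserve(value.size());
        for (char c : value)
        {
            switch (c)
            {
                case '"':  out += "\\\""; break;
                case '\\': out += "\\\\"; break;
                case '\n': out += "\\n";  break;
                case '\r': out += "\\r";  break;
                case '\t': out += "\\t";  break;
                default:   out += c;      break;
            }
        }
        return out;
    }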
1 change: 1 addition & 0 deletions playerbot/strategy/StrategyContext.h
@@ -129,6 +129,7 @@ namespace ai
creators["debug grind"] = &StrategyContext::debug_grind;
creators["debug loot"] = &StrategyContext::debug_loot;
creators["debug log"] = &StrategyContext::debug_log;
creators["debug llm"] = [](PlayerbotAI* ai) { return new DebugLLMStrategy(ai); };
creators["debug logname"] = &StrategyContext::debug_logname;
creators["rtsc"] = &StrategyContext::rtsc;
creators["rtsc jump"] = &StrategyContext::rtsc_jump;
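With the creator registered, the new strategy toggles like any other non-combat strategy; assuming the usual playerbot chat syntax for non-combat strategies (the exact command prefix depends on your setup), something like:

    nc +debug llm    (enable LLM debug output for the bot)
    nc -debug llm    (disable it)

Note that, unlike the neighbouring entries bound to static StrategyContext factory methods, "debug llm" is created by an inline lambda, which saves declaring a one-line creator on the context class.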
53 changes: 47 additions & 6 deletions playerbot/strategy/actions/SayAction.cpp
@@ -199,7 +199,7 @@ void ChatReplyAction::ChatReplyDo(Player* bot, uint32 type, uint32 guid1, uint32

std::string llmContext = AI_VALUE(std::string, "manual string::llmcontext" + llmChannel);

if (player && (player->isRealPlayer() || (sPlayerbotAIConfig.llmBotToBotChatChance && urand(0,99) < sPlayerbotAIConfig.llmBotToBotChatChance)))
if (player && (player->isRealPlayer() || (sPlayerbotAIConfig.llmBotToBotChatChance && urand(0, 99) < sPlayerbotAIConfig.llmBotToBotChatChance)))
{
std::string botName = bot->GetName();
std::string playerName = player->GetName();
Expand Down Expand Up @@ -271,13 +271,16 @@ void ChatReplyAction::ChatReplyDo(Player* bot, uint32 type, uint32 guid1, uint32
jsonFill["<post prompt>"] = BOT_TEXT2(sPlayerbotAIConfig.llmPostPrompt, placeholders);

uint32 currentLength = jsonFill["<pre prompt>"].size() + jsonFill["<context>"].size() + jsonFill["<prompt>"].size() + llmContext.size();

PlayerbotLLMInterface::LimitContext(llmContext, currentLength);

jsonFill["<context>"] = llmContext;

llmContext += " " + jsonFill["<prompt>"];

for (auto& prompt : jsonFill)
{
prompt.second = PlayerbotLLMInterface::SanitizeForJson(prompt.second);
}

std::string json = BOT_TEXT2(sPlayerbotAIConfig.llmApiJson, jsonFill);

json = BOT_TEXT2(json, placeholders);
@@ -312,9 +315,11 @@ void ChatReplyAction::ChatReplyDo(Player* bot, uint32 type, uint32 guid1, uint32
}
}

bool debug = bot->GetPlayerbotAI()->HasStrategy("debug llm", BotState::BOT_STATE_NON_COMBAT);

WorldSession* session = bot->GetSession();

std::future<std::vector<WorldPacket>> futurePackets = std::async([type, playerName, json, startPattern, endPattern, splitPattern] {
std::future<std::vector<WorldPacket>> futurePackets = std::async([type, playerName, json, startPattern, endPattern, splitPattern, debug] {

WorldPacket packet_template(CMSG_MESSAGECHAT);

@@ -326,11 +331,47 @@ void ChatReplyAction::ChatReplyDo(Player* bot, uint32 type, uint32 guid1, uint32
if (type == CHAT_MSG_WHISPER)
packet_template << playerName;

std::string response = PlayerbotLLMInterface::Generate(json);
std::vector<std::string> debugLines;

std::vector<std::string> lines = PlayerbotLLMInterface::ParseResponse(response, startPattern, endPattern, splitPattern);
if (debug)
debugLines = { json };

std::string response = PlayerbotLLMInterface::Generate(json, debugLines);

std::vector<std::string> lines = PlayerbotLLMInterface::ParseResponse(response, startPattern, endPattern, splitPattern, debugLines);

std::vector<WorldPacket> packets;

if (debug)
{
for (auto& line : debugLines)
{
std::string sentence = line;
while (sentence.length() > 200) {
size_t split_pos = sentence.rfind(' ', 200);
size_t skip = 1; // we split on a space; drop it from the next chunk
if (split_pos == std::string::npos) {
split_pos = 200;
skip = 0; // no space found: hard split, keep every character
}

if (!sentence.substr(0, split_pos).empty())
{
WorldPacket packet(packet_template);
packet << sentence.substr(0, split_pos);
packets.push_back(packet);
}

sentence = sentence.substr(split_pos + skip);
}

if (!sentence.empty())
{
WorldPacket packet(packet_template);
packet << sentence;
packets.push_back(packet);
}
}
}

for (auto& line : lines)
{
WorldPacket packet(packet_template);
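The 200-character word-wrap above lives inline in the async lambda. Pulled out into a standalone helper it is easier to read and test; a sketch (hypothetical name, not part of the commit) equivalent to the corrected loop:

    // Sketch: split text into chat-sized chunks, breaking at the last space
    // before maxLen where possible. 200 matches the limit used above.
    std::vector<std::string> SplitForChat(std::string sentence, size_t maxLen = 200)
    {
        std::vector<std::string> parts;

        while (sentence.length() > maxLen)
        {
            size_t splitPos = sentence.rfind(' ', maxLen);
            size_t skip = 1;            // drop the space we split on
            if (splitPos == std::string::npos)
            {
                splitPos = maxLen;      // no space: hard split mid-word
                skip = 0;
            }

            if (splitPos > 0)
                parts.push_back(sentence.substr(0, splitPos));

            sentence.erase(0, splitPos + skip);
        }

        if (!sentence.empty())
            parts.push_back(sentence);

        return parts;
    }

A 450-character raw HTTP response, for example, then becomes three whisper-sized chunks, each broken at the last space before the limit.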
16 changes: 16 additions & 0 deletions playerbot/strategy/generic/DebugStrategy.h
@@ -195,4 +195,20 @@ namespace ai
virtual std::vector<std::string> GetRelatedStrategies() { return { "" }; }
#endif
};

class DebugLLMStrategy : public Strategy
{
public:
DebugLLMStrategy(PlayerbotAI* ai) : Strategy(ai) {}
virtual int GetType() { return STRATEGY_TYPE_NONCOMBAT; }
virtual std::string getName() { return "debug llm"; }
#ifdef GenerateBotHelp
virtual std::string GetHelpName() { return "debug llm"; } // Must equal internal name
virtual std::string GetHelpDescription() {
return "This strategy will give debug output while using the ai chat sytem";
}
virtual std::vector<std::string> GetRelatedStrategies() { return { "debug" }; }
#endif
};

}
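Putting the pieces together: with debug llm active, each LLM reply is preceded by the debugLines collected along the way, whispered back as 200-character chunks. Roughly, in the order the code above pushes them (hostname, port, path, and patterns below are placeholder values, not output from a real run):

    <json request body>
    Resolve hostname to IP address: localhost 5001
    Create a socket
    Connect to the server
    Send the request: POST /completion HTTP/1.1 ...
    Read the response
    <raw HTTP response>
    start pattern:<configured start pattern>
    end pattern:<configured end pattern>
    split pattern:<configured split pattern>
    <parsed reply lines>

followed by the normal chat lines themselves.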
