diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 122654514..4b3d36a16 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -5,8 +5,10 @@ import uuid import time import json +import ctypes import fnmatch import multiprocessing + from typing import ( List, Optional, @@ -20,7 +22,6 @@ from collections import deque from pathlib import Path -import ctypes from llama_cpp.llama_types import List @@ -1789,7 +1790,7 @@ def save_state(self) -> LlamaState: state_size = llama_cpp.llama_get_state_size(self._ctx.ctx) if self.verbose: print(f"Llama.save_state: got state size: {state_size}", file=sys.stderr) - llama_state = (llama_cpp.c_uint8 * int(state_size))() + llama_state = (ctypes.c_uint8 * int(state_size))() if self.verbose: print("Llama.save_state: allocated state", file=sys.stderr) n_bytes = llama_cpp.llama_copy_state_data(self._ctx.ctx, llama_state) @@ -1797,7 +1798,7 @@ def save_state(self) -> LlamaState: print(f"Llama.save_state: copied llama state: {n_bytes}", file=sys.stderr) if int(n_bytes) > int(state_size): raise RuntimeError("Failed to copy llama state data") - llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))() + llama_state_compact = (ctypes.c_uint8 * int(n_bytes))() llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes)) if self.verbose: print(