-
Notifications
You must be signed in to change notification settings - Fork 212
/
Copy path_base_client.py
2081 lines (1771 loc) · 67.6 KB
/
_base_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import sys
import json
import time
import uuid
import email
import asyncio
import inspect
import logging
import platform
import warnings
import email.utils
from types import TracebackType
from random import random
from typing import (
TYPE_CHECKING,
Any,
Dict,
Type,
Union,
Generic,
Mapping,
TypeVar,
Iterable,
Iterator,
Optional,
Generator,
AsyncIterator,
cast,
overload,
)
from typing_extensions import Literal, override, get_origin
import anyio
import httpx
import distro
import pydantic
from httpx import URL, Limits
from pydantic import PrivateAttr
from . import _exceptions
from ._qs import Querystring
from ._files import to_httpx_files, async_to_httpx_files
from ._types import (
NOT_GIVEN,
Body,
Omit,
Query,
Headers,
Timeout,
NotGiven,
ResponseT,
Transport,
AnyMapping,
PostParser,
ProxiesTypes,
RequestFiles,
HttpxSendArgs,
AsyncTransport,
RequestOptions,
HttpxRequestFiles,
ModelBuilderProtocol,
)
from ._utils import is_dict, is_list, asyncify, is_given, lru_cache, is_mapping
from ._compat import model_copy, model_dump
from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type
from ._response import (
APIResponse,
BaseAPIResponse,
AsyncAPIResponse,
extract_response_type,
)
from ._constants import (
DEFAULT_TIMEOUT,
MAX_RETRY_DELAY,
DEFAULT_MAX_RETRIES,
INITIAL_RETRY_DELAY,
RAW_RESPONSE_HEADER,
OVERRIDE_CAST_TO_HEADER,
DEFAULT_CONNECTION_LIMITS,
)
from ._streaming import Stream, SSEDecoder, AsyncStream, SSEBytesDecoder
from ._exceptions import (
APIStatusError,
APITimeoutError,
APIConnectionError,
APIResponseValidationError,
)
from ._legacy_response import LegacyAPIResponse
log: logging.Logger = logging.getLogger(__name__)

# TODO: make base page type vars covariant
# Self-types for the page classes: `iter_pages()` / `get_next_page()` are
# annotated with `self: SyncPageT` / `self: AsyncPageT` so they return the
# concrete page subclass rather than the base class.
SyncPageT = TypeVar("SyncPageT", bound="BaseSyncPage[Any]")
AsyncPageT = TypeVar("AsyncPageT", bound="BaseAsyncPage[Any]")

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)

# Stream type vars, bounded by the sync and async stream classes respectively.
_StreamT = TypeVar("_StreamT", bound=Stream[Any])
_AsyncStreamT = TypeVar("_AsyncStreamT", bound=AsyncStream[Any])

# httpx's own default timeout. The sync client compares a user-supplied
# `http_client.timeout` against it to decide whether the user customised the
# timeout. `DEFAULT_TIMEOUT_CONFIG` lives in a private httpx module, so an
# equivalent value is used if that import ever fails at runtime.
if TYPE_CHECKING:
    from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
else:
    try:
        from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
    except ImportError:
        # taken from https://github.com/encode/httpx/blob/3ba5fe0d7ac70222590e759c31442b1cab263791/httpx/_config.py#L366
        HTTPX_DEFAULT_TIMEOUT = Timeout(5.0)
class PageInfo:
    """Request details needed to fetch the following page of results.

    Either `url` or `params` must be set. The typed overloads accept exactly one
    of them; whichever is omitted stays `NOT_GIVEN`.
    """

    url: URL | NotGiven
    params: Query | NotGiven

    @overload
    def __init__(
        self,
        *,
        url: URL,
    ) -> None: ...

    @overload
    def __init__(
        self,
        *,
        params: Query,
    ) -> None: ...

    def __init__(
        self,
        *,
        url: URL | NotGiven = NOT_GIVEN,
        params: Query | NotGiven = NOT_GIVEN,
    ) -> None:
        self.params = params
        self.url = url

    @override
    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        # Show the URL when one is set; otherwise fall back to the params.
        if not self.url:
            return f"{cls_name}(params={self.params})"
        return f"{cls_name}(url={self.url})"
class BasePage(GenericModel, Generic[_T]):
    """
    Core pagination interface shared by the sync and async page classes.

    Type Args:
        ModelT: The pydantic model that represents an item in the response.

    Methods:
        has_next_page(): Check if there is another page available
        next_page_info(): Get the necessary information to make a request for the next page
    """

    _options: FinalRequestOptions = PrivateAttr()
    _model: Type[_T] = PrivateAttr()

    def has_next_page(self) -> bool:
        # A page without items ends the iteration, whatever `next_page_info()` says.
        if not self._get_page_items():
            return False
        return self.next_page_info() is not None

    def next_page_info(self) -> Optional[PageInfo]: ...

    def _get_page_items(self) -> Iterable[_T]:  # type: ignore[empty-body]
        ...

    def _params_from_url(self, url: URL) -> httpx.QueryParams:
        # TODO: do we have to preprocess params here?
        base_params = httpx.QueryParams(cast(Any, self._options.params))
        return base_params.merge(url.params)

    def _info_to_options(self, info: PageInfo) -> FinalRequestOptions:
        """Derive the request options for the page described by `info`.

        Starts from a copy of this page's options (with the raw-response header
        stripped). Then it either merges `info.params` into the existing params,
        or rewrites the URL and params from `info.url`.

        Raises:
            ValueError: if `info` carries neither params nor a URL.
        """
        next_options = model_copy(self._options)
        next_options._strip_raw_response_header()

        page_params = info.params
        if not isinstance(page_params, NotGiven):
            next_options.params = {**next_options.params, **page_params}
            return next_options

        page_url = info.url
        if isinstance(page_url, NotGiven):
            raise ValueError("Unexpected PageInfo state")

        merged_url = page_url.copy_with(params=self._params_from_url(page_url))
        next_options.params = dict(merged_url.params)
        next_options.url = str(merged_url)
        return next_options
class BaseSyncPage(BasePage[_T], Generic[_T]):
    """Base class for pages returned by the synchronous client."""

    _client: SyncAPIClient = pydantic.PrivateAttr()

    def _set_private_attributes(
        self,
        client: SyncAPIClient,
        model: Type[_T],
        options: FinalRequestOptions,
    ) -> None:
        # Keep the request context that `get_next_page()` needs.
        self._options = options
        self._client = client
        self._model = model

    # Pydantic defines `__iter__` so that `dict(model)` works. Supporting
    # `for item in page` requires overriding it, which gives up `dict(page)`;
    # both cannot be supported at once. This is acceptable because
    # `model.dict()` remains available and is what pydantic itself uses
    # internally to convert a model to a dictionary.
    def __iter__(self) -> Iterator[_T]:  # type: ignore
        for page in self.iter_pages():
            yield from page._get_page_items()

    def iter_pages(self: SyncPageT) -> Iterator[SyncPageT]:
        """Yield this page, then each following page until none remain."""
        current = self
        yield current
        while current.has_next_page():
            current = current.get_next_page()
            yield current

    def get_next_page(self: SyncPageT) -> SyncPageT:
        """Fetch the page after this one.

        Raises:
            RuntimeError: if this page reports no next page.
        """
        next_info = self.next_page_info()
        if not next_info:
            raise RuntimeError(
                "No next page expected; please check `.has_next_page()` before calling `.get_next_page()`."
            )
        next_options = self._info_to_options(next_info)
        return self._client._request_api_list(self._model, page=self.__class__, options=next_options)
class AsyncPaginator(Generic[_T, AsyncPageT]):
    """Lazy handle for a paginated async request.

    `await paginator` fetches the first page; `async for item in paginator`
    fetches the first page and then iterates that page's items asynchronously.
    """

    def __init__(
        self,
        client: AsyncAPIClient,
        options: FinalRequestOptions,
        page_cls: Type[AsyncPageT],
        model: Type[_T],
    ) -> None:
        self._client = client
        self._options = options
        self._page_cls = page_cls
        self._model = model

    def __await__(self) -> Generator[Any, None, AsyncPageT]:
        # Delegate to the coroutine so awaiting the paginator yields the first page.
        return self._get_page().__await__()

    async def _get_page(self) -> AsyncPageT:
        def _attach_context(page: AsyncPageT) -> AsyncPageT:
            # Give the parsed page what it needs to request subsequent pages.
            page._set_private_attributes(
                model=self._model,
                options=self._options,
                client=self._client,
            )
            return page

        self._options.post_parser = _attach_context
        return await self._client.request(self._page_cls, self._options)

    async def __aiter__(self) -> AsyncIterator[_T]:
        # https://github.com/microsoft/pyright/issues/3464
        first_page = cast(
            AsyncPageT,
            await self,  # type: ignore
        )
        async for entry in first_page:
            yield entry
class BaseAsyncPage(BasePage[_T], Generic[_T]):
    """Base class for pages returned by the asynchronous client."""

    _client: AsyncAPIClient = pydantic.PrivateAttr()

    def _set_private_attributes(
        self,
        model: Type[_T],
        client: AsyncAPIClient,
        options: FinalRequestOptions,
    ) -> None:
        # Keep the request context that `get_next_page()` needs.
        self._options = options
        self._client = client
        self._model = model

    async def __aiter__(self) -> AsyncIterator[_T]:
        # Walk every page in turn, yielding its items in order.
        async for current_page in self.iter_pages():
            for entry in current_page._get_page_items():
                yield entry

    async def iter_pages(self: AsyncPageT) -> AsyncIterator[AsyncPageT]:
        """Yield this page, then each following page until none remain."""
        current = self
        yield current
        while current.has_next_page():
            current = await current.get_next_page()
            yield current

    async def get_next_page(self: AsyncPageT) -> AsyncPageT:
        """Fetch the page after this one.

        Raises:
            RuntimeError: if this page reports no next page.
        """
        next_info = self.next_page_info()
        if not next_info:
            raise RuntimeError(
                "No next page expected; please check `.has_next_page()` before calling `.get_next_page()`."
            )
        next_options = self._info_to_options(next_info)
        return await self._client._request_api_list(self._model, page=self.__class__, options=next_options)
# Type of the underlying httpx client held by `BaseClient` (sync or async).
_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient])
# Type of `BaseClient._default_stream_cls` (a sync or async stream class).
_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]])
class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
    """Core shared by the sync and async API clients.

    Holds the client configuration (base URL, retries, timeout, custom headers
    and query). It also implements request building, response parsing and the
    retry policy. Subclasses provide the concrete httpx client (`_client`) and
    implement `_make_status_error`.
    """

    _client: _HttpxClientT
    _version: str
    _base_url: URL
    max_retries: int
    timeout: Union[float, Timeout, None]
    _limits: httpx.Limits
    _proxies: ProxiesTypes | None
    _transport: Transport | AsyncTransport | None
    _strict_response_validation: bool
    _idempotency_header: str | None
    _default_stream_cls: type[_DefaultStreamT] | None = None

    def __init__(
        self,
        *,
        version: str,
        base_url: str | URL,
        _strict_response_validation: bool,
        max_retries: int = DEFAULT_MAX_RETRIES,
        timeout: float | Timeout | None = DEFAULT_TIMEOUT,
        limits: httpx.Limits,
        transport: Transport | AsyncTransport | None,
        proxies: ProxiesTypes | None,
        custom_headers: Mapping[str, str] | None = None,
        custom_query: Mapping[str, object] | None = None,
    ) -> None:
        """Store the shared client configuration.

        Raises:
            TypeError: if `max_retries` is `None`.
        """
        self._version = version
        self._base_url = self._enforce_trailing_slash(URL(base_url))
        self.max_retries = max_retries
        self.timeout = timeout
        self._limits = limits
        self._proxies = proxies
        self._transport = transport
        self._custom_headers = custom_headers or {}
        self._custom_query = custom_query or {}
        self._strict_response_validation = _strict_response_validation
        self._idempotency_header = None
        self._platform: Platform | None = None

        # `None` is rejected explicitly; the message lists the supported alternatives.
        if max_retries is None:  # pyright: ignore[reportUnnecessaryComparison]
            raise TypeError(
                "max_retries cannot be None. If you want to disable retries, pass `0`; if you want unlimited retries, pass `math.inf` or a very high number; if you want the default behavior, pass `anthropic.DEFAULT_MAX_RETRIES`"
            )

    def _enforce_trailing_slash(self, url: URL) -> URL:
        """Return `url` with a raw path that ends in `/`.

        `_prepare_url` appends relative request paths directly to the base raw
        path, so the base must end in a slash to behave like a directory.
        """
        if url.raw_path.endswith(b"/"):
            return url
        return url.copy_with(raw_path=url.raw_path + b"/")

    def _make_status_error_from_response(
        self,
        response: httpx.Response,
    ) -> APIStatusError:
        """Build a status error for an unsuccessful `response`.

        The body is parsed as JSON when possible and otherwise kept as stripped
        text. It is `None` when the response was closed before being read.
        """
        if response.is_closed and not response.is_stream_consumed:
            # We can't read the response body as it has been closed
            # before it was read. This can happen if an event hook
            # raises a status error.
            body = None
            err_msg = f"Error code: {response.status_code}"
        else:
            err_text = response.text.strip()
            body = err_text

            try:
                body = json.loads(err_text)
                err_msg = f"Error code: {response.status_code} - {body}"
            except Exception:
                err_msg = err_text or f"Error code: {response.status_code}"

        return self._make_status_error(err_msg, body=body, response=response)

    def _make_status_error(
        self,
        err_msg: str,
        *,
        body: object,
        response: httpx.Response,
    ) -> _exceptions.APIStatusError:
        """Create the API-specific status error; subclasses must override this."""
        raise NotImplementedError()

    def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0) -> httpx.Headers:
        """Merge default and per-request headers into an `httpx.Headers`.

        It also adds an idempotency key to non-GET requests when the client has
        an idempotency header configured. It sets `x-stainless-retry-count`
        unless the caller set or removed that header.
        """
        custom_headers = options.headers or {}
        headers_dict = _merge_mappings(self.default_headers, custom_headers)
        self._validate_headers(headers_dict, custom_headers)

        # headers are case-insensitive while dictionaries are not.
        headers = httpx.Headers(headers_dict)

        idempotency_header = self._idempotency_header
        if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
            headers[idempotency_header] = options.idempotency_key or self._idempotency_key()

        # Don't set the retry count header if it was already set or removed by the caller. We check
        # `custom_headers`, which can contain `Omit()`, instead of `headers` to account for the removal case.
        if "x-stainless-retry-count" not in (header.lower() for header in custom_headers):
            headers["x-stainless-retry-count"] = str(retries_taken)

        return headers

    def _prepare_url(self, url: str) -> URL:
        """
        Merge a URL argument together with any 'base_url' on the client,
        to create the URL used for the outgoing request.
        """
        # Copied from httpx's `_merge_url` method.
        merge_url = URL(url)
        if merge_url.is_relative_url:
            merge_raw_path = self.base_url.raw_path + merge_url.raw_path.lstrip(b"/")
            return self.base_url.copy_with(raw_path=merge_raw_path)

        return merge_url

    def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder:
        """Return a new `SSEDecoder`; subclasses may override to supply another decoder."""
        return SSEDecoder()

    def _build_request(
        self,
        options: FinalRequestOptions,
        *,
        retries_taken: int = 0,
    ) -> httpx.Request:
        """Translate `options` into an `httpx.Request` built by the underlying client.

        Merges `extra_json` into the JSON body and applies the default headers
        and query params. It switches to multipart form encoding when the
        `Content-Type` header requests `multipart/form-data`.

        Raises:
            RuntimeError: if `extra_json` cannot be merged into a non-mapping body.
            TypeError: if a multipart request has a non-dict body.
        """
        if log.isEnabledFor(logging.DEBUG):
            log.debug("Request options: %s", model_dump(options, exclude_unset=True))

        kwargs: dict[str, Any] = {}

        json_data = options.json_data
        if options.extra_json is not None:
            if json_data is None:
                json_data = cast(Body, options.extra_json)
            elif is_mapping(json_data):
                json_data = _merge_mappings(json_data, options.extra_json)
            else:
                raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`")

        headers = self._build_headers(options, retries_taken=retries_taken)
        params = _merge_mappings(self.default_query, options.params)
        content_type = headers.get("Content-Type")
        files = options.files

        # If the given Content-Type header is multipart/form-data then it
        # has to be removed so that httpx can generate the header with
        # additional information for us as it has to be in this form
        # for the server to be able to correctly parse the request:
        # multipart/form-data; boundary=---abc--
        if content_type is not None and content_type.startswith("multipart/form-data"):
            if "boundary" not in content_type:
                # only remove the header if the boundary hasn't been explicitly set
                # as the caller doesn't want httpx to come up with their own boundary
                headers.pop("Content-Type")

            # As we are now sending multipart/form-data instead of application/json
            # we need to tell httpx to use it, https://www.python-httpx.org/advanced/clients/#multipart-file-encoding
            if json_data:
                if not is_dict(json_data):
                    raise TypeError(
                        f"Expected query input to be a dictionary for multipart requests but got {type(json_data)} instead."
                    )
                kwargs["data"] = self._serialize_multipartform(json_data)

            # httpx determines whether or not to send a "multipart/form-data"
            # request based on the truthiness of the "files" argument.
            # This gets around that issue by generating a dict value that
            # evaluates to true.
            #
            # https://github.com/encode/httpx/discussions/2399#discussioncomment-3814186
            if not files:
                files = cast(HttpxRequestFiles, ForceMultipartDict())

        prepared_url = self._prepare_url(options.url)
        if "_" in prepared_url.host:
            # work around https://github.com/encode/httpx/discussions/2880
            kwargs["extensions"] = {"sni_hostname": prepared_url.host.replace("_", "-")}

        # TODO: report this error to httpx
        return self._client.build_request(  # pyright: ignore[reportUnknownMemberType]
            headers=headers,
            timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout,
            method=options.method,
            url=prepared_url,
            # the `Query` type that we use is incompatible with qs'
            # `Params` type as it needs to be typed as `Mapping[str, object]`
            # so that passing a `TypedDict` doesn't cause an error.
            # https://github.com/microsoft/pyright/issues/3526#event-6715453066
            params=self.qs.stringify(cast(Mapping[str, Any], params)) if params else None,
            json=json_data,
            files=files,
            **kwargs,
        )

    def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
        """Flatten `data` into form fields for a multipart request.

        Values are stringified via `self.qs` using the `brackets` array format.
        Keys that occur more than once are collected into a list so that httpx
        sends each value as its own field.
        """
        items = self.qs.stringify_items(
            # TODO: type ignore is required as stringify_items is well typed but we can't be
            # well typed without heavy validation.
            data,  # type: ignore
            array_format="brackets",
        )
        serialized: dict[str, object] = {}
        for key, value in items:
            # Test membership rather than truthiness. Otherwise a falsy first
            # value (e.g. an empty string) would be silently overwritten by a
            # later value for the same key instead of both being sent.
            if key not in serialized:
                serialized[key] = value
                continue

            existing = serialized[key]
            # If a value has already been set for this key then that
            # means we're sending data like `array[]=[1, 2, 3]` and we
            # need to tell httpx that we want to send multiple values with
            # the same key which is done by using a list or a tuple.
            #
            # Note: 2d arrays should never result in the same key at both
            # levels so it's safe to assume that if the value is a list,
            # it was because we changed it to be a list.
            if is_list(existing):
                existing.append(value)
            else:
                serialized[key] = [existing, value]

        return serialized

    def _maybe_override_cast_to(self, cast_to: type[ResponseT], options: FinalRequestOptions) -> type[ResponseT]:
        """Return the `cast_to` override carried in `OVERRIDE_CAST_TO_HEADER`, if any.

        When the header is present, it is removed from a copy of the headers and
        that copy is assigned back to `options.headers`.
        """
        if not is_given(options.headers):
            return cast_to

        # make a copy of the headers so we don't mutate user-input
        headers = dict(options.headers)

        # we internally support defining a temporary header to override the
        # default `cast_to` type for use with `.with_raw_response` and `.with_streaming_response`
        # see _response.py for implementation details
        override_cast_to = headers.pop(OVERRIDE_CAST_TO_HEADER, NOT_GIVEN)
        if is_given(override_cast_to):
            options.headers = headers
            return cast(Type[ResponseT], override_cast_to)

        return cast_to

    def _should_stream_response_body(self, request: httpx.Request) -> bool:
        """Return True when the request's raw-response header is set to "stream"."""
        return request.headers.get(RAW_RESPONSE_HEADER) == "stream"  # type: ignore[no-any-return]

    def _process_response_data(
        self,
        *,
        data: object,
        cast_to: type[ResponseT],
        response: httpx.Response,
    ) -> ResponseT:
        """Convert decoded response `data` into `cast_to`.

        `None` data and `cast_to=object` are returned unchanged. Classes that
        implement `ModelBuilderProtocol` build themselves. Otherwise the data
        goes through `validate_type` in strict mode, or `construct_type` when
        strict mode is off.

        Raises:
            APIResponseValidationError: if pydantic validation fails.
        """
        if data is None:
            return cast(ResponseT, None)

        if cast_to is object:
            return cast(ResponseT, data)

        try:
            if inspect.isclass(cast_to) and issubclass(cast_to, ModelBuilderProtocol):
                return cast(ResponseT, cast_to.build(response=response, data=data))

            if self._strict_response_validation:
                return cast(ResponseT, validate_type(type_=cast_to, value=data))

            return cast(ResponseT, construct_type(type_=cast_to, value=data))
        except pydantic.ValidationError as err:
            raise APIResponseValidationError(response=response, body=data) from err

    @property
    def qs(self) -> Querystring:
        """Serializer used for query params and multipart form fields."""
        return Querystring()

    @property
    def custom_auth(self) -> httpx.Auth | None:
        """Optional httpx auth passed to `send`; `None` by default."""
        return None

    @property
    def auth_headers(self) -> dict[str, str]:
        """Authentication headers merged into `default_headers`; empty by default."""
        return {}

    @property
    def default_headers(self) -> dict[str, str | Omit]:
        """Headers sent with every request.

        Later entries win: platform headers, then auth headers, then the
        client's custom headers.
        """
        return {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "User-Agent": self.user_agent,
            **self.platform_headers(),
            **self.auth_headers,
            **self._custom_headers,
        }

    @property
    def default_query(self) -> dict[str, object]:
        """Query params merged into every request (the client's custom query)."""
        return {
            **self._custom_query,
        }

    def _validate_headers(
        self,
        headers: Headers,  # noqa: ARG002
        custom_headers: Headers,  # noqa: ARG002
    ) -> None:
        """Validate the given default headers and custom headers.

        Does nothing by default; subclasses may override it to raise on
        invalid header combinations.
        """
        return

    @property
    def user_agent(self) -> str:
        """User-Agent of the form `<ClassName>/Python <version>`."""
        return f"{self.__class__.__name__}/Python {self._version}"

    @property
    def base_url(self) -> URL:
        """Base URL that relative request paths are resolved against."""
        return self._base_url

    @base_url.setter
    def base_url(self, url: URL | str) -> None:
        self._base_url = self._enforce_trailing_slash(url if isinstance(url, URL) else URL(url))

    def platform_headers(self) -> Dict[str, str]:
        """Return the platform-describing headers for this SDK version."""
        # the actual implementation is in a separate `lru_cache` decorated
        # function because adding `lru_cache` to methods will leak memory
        # https://github.com/python/cpython/issues/88476
        return platform_headers(self._version, platform=self._platform)

    def _parse_retry_after_header(self, response_headers: Optional[httpx.Headers] = None) -> float | None:
        """Returns a float of the number of seconds (not milliseconds) to wait after retrying, or None if unspecified.

        About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
        See also https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax
        """
        if response_headers is None:
            return None

        # First, try the non-standard `retry-after-ms` header for milliseconds,
        # which is more precise than integer-seconds `retry-after`
        try:
            retry_ms_header = response_headers.get("retry-after-ms", None)
            return float(retry_ms_header) / 1000
        except (TypeError, ValueError):
            pass

        # Next, try parsing `retry-after` header as seconds (allowing nonstandard floats).
        retry_header = response_headers.get("retry-after")
        try:
            # note: the spec indicates that this should only ever be an integer
            # but if someone sends a float there's no reason for us to not respect it
            return float(retry_header)
        except (TypeError, ValueError):
            pass

        # Last, try parsing `retry-after` as a date.
        retry_date_tuple = email.utils.parsedate_tz(retry_header)
        if retry_date_tuple is None:
            return None

        retry_date = email.utils.mktime_tz(retry_date_tuple)
        return float(retry_date - time.time())

    def _calculate_retry_timeout(
        self,
        remaining_retries: int,
        options: FinalRequestOptions,
        response_headers: Optional[httpx.Headers] = None,
    ) -> float:
        """Return how many seconds to wait before the next retry.

        A server-provided retry delay greater than 0 and at most 60 seconds is
        used as-is. Otherwise the delay is exponential backoff from
        `INITIAL_RETRY_DELAY`, capped at `MAX_RETRY_DELAY`, then shortened by
        random jitter of up to 25%.
        """
        max_retries = options.get_max_retries(self.max_retries)

        # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
        retry_after = self._parse_retry_after_header(response_headers)
        if retry_after is not None and 0 < retry_after <= 60:
            return retry_after

        # Also cap retry count to 1000 to avoid any potential overflows with `pow`
        nb_retries = min(max_retries - remaining_retries, 1000)

        # Apply exponential backoff, but not more than the max.
        sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY)

        # Apply some jitter: `random()` is in [0, 1), so the delay is scaled by a
        # factor in (0.75, 1], i.e. shortened by up to 25%.
        jitter = 1 - 0.25 * random()
        timeout = sleep_seconds * jitter
        return timeout if timeout >= 0 else 0

    def _should_retry(self, response: httpx.Response) -> bool:
        """Decide whether a failed `response` should be retried.

        An explicit `x-should-retry: true|false` header wins. Otherwise the
        request is retried on 408, 409, 429 and any 5xx status.
        """
        # Note: this is not a standard header
        should_retry_header = response.headers.get("x-should-retry")

        # If the server explicitly says whether or not to retry, obey.
        if should_retry_header == "true":
            log.debug("Retrying as header `x-should-retry` is set to `true`")
            return True
        if should_retry_header == "false":
            log.debug("Not retrying as header `x-should-retry` is set to `false`")
            return False

        # Retry on request timeouts.
        if response.status_code == 408:
            log.debug("Retrying due to status code %i", response.status_code)
            return True

        # Retry on lock timeouts.
        if response.status_code == 409:
            log.debug("Retrying due to status code %i", response.status_code)
            return True

        # Retry on rate limits.
        if response.status_code == 429:
            log.debug("Retrying due to status code %i", response.status_code)
            return True

        # Retry internal errors.
        if response.status_code >= 500:
            log.debug("Retrying due to status code %i", response.status_code)
            return True

        log.debug("Not retrying")
        return False

    def _idempotency_key(self) -> str:
        """Generate a fresh random idempotency key."""
        return f"stainless-python-retry-{uuid.uuid4()}"
class _DefaultHttpxClient(httpx.Client):
    """`httpx.Client` preconfigured with the SDK's defaults.

    The defaults are `DEFAULT_TIMEOUT`, `DEFAULT_CONNECTION_LIMITS` and
    `follow_redirects=True`. Any of them passed explicitly by the caller takes
    precedence.
    """

    def __init__(self, **kwargs: Any) -> None:
        sdk_defaults: dict[str, Any] = {
            "timeout": DEFAULT_TIMEOUT,
            "limits": DEFAULT_CONNECTION_LIMITS,
            "follow_redirects": True,
        }
        # Caller-supplied values override the defaults key by key.
        super().__init__(**{**sdk_defaults, **kwargs})
# Public name for the SDK-configured httpx client. Type checkers see plain
# `httpx.Client`; at runtime it is `_DefaultHttpxClient`, which applies the
# SDK's default timeout, connection limits and `follow_redirects=True`.
if TYPE_CHECKING:
    DefaultHttpxClient = httpx.Client
    """An alias to `httpx.Client` that provides the same defaults that this SDK
    uses internally.

    This is useful because overriding the `http_client` with your own instance of
    `httpx.Client` will result in httpx's defaults being used, not ours.
    """
else:
    DefaultHttpxClient = _DefaultHttpxClient
class SyncHttpxClientWrapper(DefaultHttpxClient):
    """Default sync httpx client that closes itself when garbage-collected.

    `SyncAPIClient` uses it when the caller did not supply an `http_client`.
    """

    def __del__(self) -> None:
        # Finalizers must never raise. The `is_closed` check is guarded too:
        # on a partially-constructed instance (e.g. `__init__` raised) reading
        # it may fail. NOTE(review): that depends on httpx internals.
        try:
            if self.is_closed:
                return
            self.close()
        except Exception:
            # Best-effort cleanup only; failures here are deliberately ignored.
            pass
class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
_client: httpx.Client
_default_stream_cls: type[Stream[Any]] | None = None
def __init__(
    self,
    *,
    version: str,
    base_url: str | URL,
    max_retries: int = DEFAULT_MAX_RETRIES,
    timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
    transport: Transport | None = None,
    proxies: ProxiesTypes | None = None,
    limits: Limits | None = None,
    http_client: httpx.Client | None = None,
    custom_headers: Mapping[str, str] | None = None,
    custom_query: Mapping[str, object] | None = None,
    _strict_response_validation: bool,
) -> None:
    """Configure the synchronous API client.

    Args:
        version: Version string used in the User-Agent and platform headers.
        base_url: Base URL that relative request paths are resolved against.
        max_retries: Retry budget; `None` is rejected by `BaseClient.__init__`.
        timeout: Request timeout. When not given, a supplied `http_client`'s
            timeout is used if it differs from httpx's default; otherwise
            `DEFAULT_TIMEOUT` is used.
        transport, proxies, limits: Deprecated. Each one emits a
            `DeprecationWarning` and cannot be combined with `http_client`.
        http_client: Optional pre-built `httpx.Client` used to send requests.
            When omitted, a `SyncHttpxClientWrapper` is created.
        custom_headers: Extra headers sent with every request.
        custom_query: Extra query params sent with every request.

    Raises:
        ValueError: if `http_client` is combined with `limits`, `transport` or `proxies`.
        TypeError: if `http_client` is not an `httpx.Client`, or `max_retries` is `None`.
    """
    # Extra constructor arguments for the internally created httpx client;
    # unused when the caller supplies `http_client`.
    kwargs: dict[str, Any] = {}
    if limits is not None:
        warnings.warn(
            "The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead",
            category=DeprecationWarning,
            stacklevel=3,
        )
        if http_client is not None:
            raise ValueError("The `http_client` argument is mutually exclusive with `connection_pool_limits`")
    else:
        limits = DEFAULT_CONNECTION_LIMITS

    if transport is not None:
        kwargs["transport"] = transport
        warnings.warn(
            "The `transport` argument is deprecated. The `http_client` argument should be passed instead",
            category=DeprecationWarning,
            stacklevel=3,
        )
        if http_client is not None:
            raise ValueError("The `http_client` argument is mutually exclusive with `transport`")

    if proxies is not None:
        kwargs["proxies"] = proxies
        warnings.warn(
            "The `proxies` argument is deprecated. The `http_client` argument should be passed instead",
            category=DeprecationWarning,
            stacklevel=3,
        )
        if http_client is not None:
            raise ValueError("The `http_client` argument is mutually exclusive with `proxies`")

    if not is_given(timeout):
        # if the user passed in a custom http client with a non-default
        # timeout set then we use that timeout.
        #
        # note: there is an edge case here where the user passes in a client
        # where they've explicitly set the timeout to match the default timeout
        # as this check is structural, meaning that we'll think they didn't
        # pass in a timeout and will ignore it
        if http_client and http_client.timeout != HTTPX_DEFAULT_TIMEOUT:
            timeout = http_client.timeout
        else:
            timeout = DEFAULT_TIMEOUT

    if http_client is not None and not isinstance(http_client, httpx.Client):  # pyright: ignore[reportUnnecessaryIsInstance]
        raise TypeError(
            f"Invalid `http_client` argument; Expected an instance of `httpx.Client` but got {type(http_client)}"
        )

    super().__init__(
        version=version,
        limits=limits,
        # cast to a valid type because mypy doesn't understand our type narrowing
        timeout=cast(Timeout, timeout),
        proxies=proxies,
        base_url=base_url,
        transport=transport,
        max_retries=max_retries,
        custom_query=custom_query,
        custom_headers=custom_headers,
        _strict_response_validation=_strict_response_validation,
    )
    # Use the caller's client as-is, or build one with the resolved settings.
    self._client = http_client or SyncHttpxClientWrapper(
        base_url=base_url,
        # cast to a valid type because mypy doesn't understand our type narrowing
        timeout=cast(Timeout, timeout),
        limits=limits,
        follow_redirects=True,
        **kwargs,  # type: ignore
    )
def is_closed(self) -> bool:
    """Report whether the underlying httpx client has been closed."""
    underlying = self._client
    return underlying.is_closed
def close(self) -> None:
    """Close the underlying HTTPX client.

    The client will *not* be usable after this.
    """
    # `_client` is assigned at the very end of `__init__`; if construction
    # raised before that point, there is nothing to close.
    if not hasattr(self, "_client"):
        return
    self._client.close()
def __enter__(self: _T) -> _T:
    """Enter a `with` block; the client itself is the bound value."""
    client = self
    return client
def __exit__(
    self,
    exc_type: type[BaseException] | None,
    exc: BaseException | None,
    exc_tb: TracebackType | None,
) -> None:
    """Close the client when leaving a `with` block.

    Returns `None`, so exceptions raised inside the block propagate.
    """
    self.close()
    return None
def _prepare_options(
    self,
    options: FinalRequestOptions,  # noqa: ARG002
) -> FinalRequestOptions:
    """Hook for adjusting request options before the request is built.

    `_request` calls it before `_build_request`. The default implementation
    hands `options` back untouched.
    """
    prepared = options
    return prepared
def _prepare_request(
    self,
    request: httpx.Request,  # noqa: ARG002
) -> None:
    """Callback for mutating the `httpx.Request` after it has been built.

    Override it to add headers (or make other changes) based on request
    properties such as `url` or `method`. The default does nothing.
    """
@overload
def request(
    self,
    cast_to: Type[ResponseT],
    options: FinalRequestOptions,
    remaining_retries: Optional[int] = None,
    *,
    stream: Literal[True],
    stream_cls: Type[_StreamT],
) -> _StreamT: ...

@overload
def request(
    self,
    cast_to: Type[ResponseT],
    options: FinalRequestOptions,
    remaining_retries: Optional[int] = None,
    *,
    stream: Literal[False] = False,
) -> ResponseT: ...

@overload
def request(
    self,
    cast_to: Type[ResponseT],
    options: FinalRequestOptions,
    remaining_retries: Optional[int] = None,
    *,
    stream: bool = False,
    stream_cls: Type[_StreamT] | None = None,
) -> ResponseT | _StreamT: ...

def request(
    self,
    cast_to: Type[ResponseT],
    options: FinalRequestOptions,
    remaining_retries: Optional[int] = None,
    *,
    stream: bool = False,
    stream_cls: type[_StreamT] | None = None,
) -> ResponseT | _StreamT:
    """Send the request described by `options` and parse the result as `cast_to`.

    An explicit `remaining_retries` is converted into the number of retries
    already taken, relative to the effective max retries for `options`.
    When it is omitted, no retries are counted as taken.
    """
    retries_taken = (
        0
        if remaining_retries is None
        else options.get_max_retries(self.max_retries) - remaining_retries
    )
    return self._request(
        cast_to=cast_to,
        options=options,
        retries_taken=retries_taken,
        stream=stream,
        stream_cls=stream_cls,
    )
def _request(
self,
*,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
retries_taken: int,
stream: bool,
stream_cls: type[_StreamT] | None,
) -> ResponseT | _StreamT:
# create a copy of the options we were given so that if the
# options are mutated later & we then retry, the retries are
# given the original options
input_options = model_copy(options)
cast_to = self._maybe_override_cast_to(cast_to, options)
options = self._prepare_options(options)
remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
request = self._build_request(options, retries_taken=retries_taken)
self._prepare_request(request)
kwargs: HttpxSendArgs = {}
if self.custom_auth is not None:
kwargs["auth"] = self.custom_auth
log.debug("Sending HTTP Request: %s %s", request.method, request.url)
try:
response = self._client.send(
request,
stream=stream or self._should_stream_response_body(request=request),
**kwargs,
)
except httpx.TimeoutException as err: