Initial commit
This commit is contained in:
@@ -0,0 +1,39 @@
|
||||
from .connection import AsyncHTTPConnection
|
||||
from .connection_pool import AsyncConnectionPool
|
||||
from .http11 import AsyncHTTP11Connection
|
||||
from .http_proxy import AsyncHTTPProxy
|
||||
from .interfaces import AsyncConnectionInterface
|
||||
|
||||
try:
|
||||
from .http2 import AsyncHTTP2Connection
|
||||
except ImportError: # pragma: nocover
|
||||
|
||||
class AsyncHTTP2Connection: # type: ignore
|
||||
def __init__(self, *args, **kwargs) -> None: # type: ignore
|
||||
raise RuntimeError(
|
||||
"Attempted to use http2 support, but the `h2` package is not "
|
||||
"installed. Use 'pip install httpcore[http2]'."
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from .socks_proxy import AsyncSOCKSProxy
|
||||
except ImportError: # pragma: nocover
|
||||
|
||||
class AsyncSOCKSProxy: # type: ignore
|
||||
def __init__(self, *args, **kwargs) -> None: # type: ignore
|
||||
raise RuntimeError(
|
||||
"Attempted to use SOCKS support, but the `socksio` package is not "
|
||||
"installed. Use 'pip install httpcore[socks]'."
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AsyncHTTPConnection",
|
||||
"AsyncConnectionPool",
|
||||
"AsyncHTTPProxy",
|
||||
"AsyncHTTP11Connection",
|
||||
"AsyncHTTP2Connection",
|
||||
"AsyncConnectionInterface",
|
||||
"AsyncSOCKSProxy",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,222 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import ssl
|
||||
import types
|
||||
import typing
|
||||
|
||||
from .._backends.auto import AutoBackend
|
||||
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
|
||||
from .._exceptions import ConnectError, ConnectTimeout
|
||||
from .._models import Origin, Request, Response
|
||||
from .._ssl import default_ssl_context
|
||||
from .._synchronization import AsyncLock
|
||||
from .._trace import Trace
|
||||
from .http11 import AsyncHTTP11Connection
|
||||
from .interfaces import AsyncConnectionInterface
|
||||
|
||||
RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc.
|
||||
|
||||
|
||||
logger = logging.getLogger("httpcore.connection")
|
||||
|
||||
|
||||
def exponential_backoff(factor: float) -> typing.Iterator[float]:
|
||||
"""
|
||||
Generate a geometric sequence that has a ratio of 2 and starts with 0.
|
||||
|
||||
For example:
|
||||
- `factor = 2`: `0, 2, 4, 8, 16, 32, 64, ...`
|
||||
- `factor = 3`: `0, 3, 6, 12, 24, 48, 96, ...`
|
||||
"""
|
||||
yield 0
|
||||
for n in itertools.count():
|
||||
yield factor * 2**n
|
||||
|
||||
|
||||
class AsyncHTTPConnection(AsyncConnectionInterface):
|
||||
def __init__(
|
||||
self,
|
||||
origin: Origin,
|
||||
ssl_context: ssl.SSLContext | None = None,
|
||||
keepalive_expiry: float | None = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
retries: int = 0,
|
||||
local_address: str | None = None,
|
||||
uds: str | None = None,
|
||||
network_backend: AsyncNetworkBackend | None = None,
|
||||
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
|
||||
) -> None:
|
||||
self._origin = origin
|
||||
self._ssl_context = ssl_context
|
||||
self._keepalive_expiry = keepalive_expiry
|
||||
self._http1 = http1
|
||||
self._http2 = http2
|
||||
self._retries = retries
|
||||
self._local_address = local_address
|
||||
self._uds = uds
|
||||
|
||||
self._network_backend: AsyncNetworkBackend = (
|
||||
AutoBackend() if network_backend is None else network_backend
|
||||
)
|
||||
self._connection: AsyncConnectionInterface | None = None
|
||||
self._connect_failed: bool = False
|
||||
self._request_lock = AsyncLock()
|
||||
self._socket_options = socket_options
|
||||
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
if not self.can_handle_request(request.url.origin):
|
||||
raise RuntimeError(
|
||||
f"Attempted to send request to {request.url.origin} on connection to {self._origin}"
|
||||
)
|
||||
|
||||
try:
|
||||
async with self._request_lock:
|
||||
if self._connection is None:
|
||||
stream = await self._connect(request)
|
||||
|
||||
ssl_object = stream.get_extra_info("ssl_object")
|
||||
http2_negotiated = (
|
||||
ssl_object is not None
|
||||
and ssl_object.selected_alpn_protocol() == "h2"
|
||||
)
|
||||
if http2_negotiated or (self._http2 and not self._http1):
|
||||
from .http2 import AsyncHTTP2Connection
|
||||
|
||||
self._connection = AsyncHTTP2Connection(
|
||||
origin=self._origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
else:
|
||||
self._connection = AsyncHTTP11Connection(
|
||||
origin=self._origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
except BaseException as exc:
|
||||
self._connect_failed = True
|
||||
raise exc
|
||||
|
||||
return await self._connection.handle_async_request(request)
|
||||
|
||||
async def _connect(self, request: Request) -> AsyncNetworkStream:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
sni_hostname = request.extensions.get("sni_hostname", None)
|
||||
timeout = timeouts.get("connect", None)
|
||||
|
||||
retries_left = self._retries
|
||||
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)
|
||||
|
||||
while True:
|
||||
try:
|
||||
if self._uds is None:
|
||||
kwargs = {
|
||||
"host": self._origin.host.decode("ascii"),
|
||||
"port": self._origin.port,
|
||||
"local_address": self._local_address,
|
||||
"timeout": timeout,
|
||||
"socket_options": self._socket_options,
|
||||
}
|
||||
async with Trace("connect_tcp", logger, request, kwargs) as trace:
|
||||
stream = await self._network_backend.connect_tcp(**kwargs)
|
||||
trace.return_value = stream
|
||||
else:
|
||||
kwargs = {
|
||||
"path": self._uds,
|
||||
"timeout": timeout,
|
||||
"socket_options": self._socket_options,
|
||||
}
|
||||
async with Trace(
|
||||
"connect_unix_socket", logger, request, kwargs
|
||||
) as trace:
|
||||
stream = await self._network_backend.connect_unix_socket(
|
||||
**kwargs
|
||||
)
|
||||
trace.return_value = stream
|
||||
|
||||
if self._origin.scheme in (b"https", b"wss"):
|
||||
ssl_context = (
|
||||
default_ssl_context()
|
||||
if self._ssl_context is None
|
||||
else self._ssl_context
|
||||
)
|
||||
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
|
||||
ssl_context.set_alpn_protocols(alpn_protocols)
|
||||
|
||||
kwargs = {
|
||||
"ssl_context": ssl_context,
|
||||
"server_hostname": sni_hostname
|
||||
or self._origin.host.decode("ascii"),
|
||||
"timeout": timeout,
|
||||
}
|
||||
async with Trace("start_tls", logger, request, kwargs) as trace:
|
||||
stream = await stream.start_tls(**kwargs)
|
||||
trace.return_value = stream
|
||||
return stream
|
||||
except (ConnectError, ConnectTimeout):
|
||||
if retries_left <= 0:
|
||||
raise
|
||||
retries_left -= 1
|
||||
delay = next(delays)
|
||||
async with Trace("retry", logger, request, kwargs) as trace:
|
||||
await self._network_backend.sleep(delay)
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._origin
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self._connection is not None:
|
||||
async with Trace("close", logger, None, {}):
|
||||
await self._connection.aclose()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
if self._connection is None:
|
||||
# If HTTP/2 support is enabled, and the resulting connection could
|
||||
# end up as HTTP/2 then we should indicate the connection as being
|
||||
# available to service multiple requests.
|
||||
return (
|
||||
self._http2
|
||||
and (self._origin.scheme == b"https" or not self._http1)
|
||||
and not self._connect_failed
|
||||
)
|
||||
return self._connection.is_available()
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
if self._connection is None:
|
||||
return self._connect_failed
|
||||
return self._connection.has_expired()
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
if self._connection is None:
|
||||
return self._connect_failed
|
||||
return self._connection.is_idle()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
if self._connection is None:
|
||||
return self._connect_failed
|
||||
return self._connection.is_closed()
|
||||
|
||||
def info(self) -> str:
|
||||
if self._connection is None:
|
||||
return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
|
||||
return self._connection.info()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} [{self.info()}]>"
|
||||
|
||||
# These context managers are not used in the standard flow, but are
|
||||
# useful for testing or working with connection instances directly.
|
||||
|
||||
async def __aenter__(self) -> AsyncHTTPConnection:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None = None,
|
||||
exc_value: BaseException | None = None,
|
||||
traceback: types.TracebackType | None = None,
|
||||
) -> None:
|
||||
await self.aclose()
|
||||
@@ -0,0 +1,420 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ssl
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
|
||||
from .._backends.auto import AutoBackend
|
||||
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
|
||||
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
|
||||
from .._models import Origin, Proxy, Request, Response
|
||||
from .._synchronization import AsyncEvent, AsyncShieldCancellation, AsyncThreadLock
|
||||
from .connection import AsyncHTTPConnection
|
||||
from .interfaces import AsyncConnectionInterface, AsyncRequestInterface
|
||||
|
||||
|
||||
class AsyncPoolRequest:
|
||||
def __init__(self, request: Request) -> None:
|
||||
self.request = request
|
||||
self.connection: AsyncConnectionInterface | None = None
|
||||
self._connection_acquired = AsyncEvent()
|
||||
|
||||
def assign_to_connection(self, connection: AsyncConnectionInterface | None) -> None:
|
||||
self.connection = connection
|
||||
self._connection_acquired.set()
|
||||
|
||||
def clear_connection(self) -> None:
|
||||
self.connection = None
|
||||
self._connection_acquired = AsyncEvent()
|
||||
|
||||
async def wait_for_connection(
|
||||
self, timeout: float | None = None
|
||||
) -> AsyncConnectionInterface:
|
||||
if self.connection is None:
|
||||
await self._connection_acquired.wait(timeout=timeout)
|
||||
assert self.connection is not None
|
||||
return self.connection
|
||||
|
||||
def is_queued(self) -> bool:
|
||||
return self.connection is None
|
||||
|
||||
|
||||
class AsyncConnectionPool(AsyncRequestInterface):
|
||||
"""
|
||||
A connection pool for making HTTP requests.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ssl_context: ssl.SSLContext | None = None,
|
||||
proxy: Proxy | None = None,
|
||||
max_connections: int | None = 10,
|
||||
max_keepalive_connections: int | None = None,
|
||||
keepalive_expiry: float | None = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
retries: int = 0,
|
||||
local_address: str | None = None,
|
||||
uds: str | None = None,
|
||||
network_backend: AsyncNetworkBackend | None = None,
|
||||
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
A connection pool for making HTTP requests.
|
||||
|
||||
Parameters:
|
||||
ssl_context: An SSL context to use for verifying connections.
|
||||
If not specified, the default `httpcore.default_ssl_context()`
|
||||
will be used.
|
||||
max_connections: The maximum number of concurrent HTTP connections that
|
||||
the pool should allow. Any attempt to send a request on a pool that
|
||||
would exceed this amount will block until a connection is available.
|
||||
max_keepalive_connections: The maximum number of idle HTTP connections
|
||||
that will be maintained in the pool.
|
||||
keepalive_expiry: The duration in seconds that an idle HTTP connection
|
||||
may be maintained for before being expired from the pool.
|
||||
http1: A boolean indicating if HTTP/1.1 requests should be supported
|
||||
by the connection pool. Defaults to True.
|
||||
http2: A boolean indicating if HTTP/2 requests should be supported by
|
||||
the connection pool. Defaults to False.
|
||||
retries: The maximum number of retries when trying to establish a
|
||||
connection.
|
||||
local_address: Local address to connect from. Can also be used to connect
|
||||
using a particular address family. Using `local_address="0.0.0.0"`
|
||||
will connect using an `AF_INET` address (IPv4), while using
|
||||
`local_address="::"` will connect using an `AF_INET6` address (IPv6).
|
||||
uds: Path to a Unix Domain Socket to use instead of TCP sockets.
|
||||
network_backend: A backend instance to use for handling network I/O.
|
||||
socket_options: Socket options that have to be included
|
||||
in the TCP socket when the connection was established.
|
||||
"""
|
||||
self._ssl_context = ssl_context
|
||||
self._proxy = proxy
|
||||
self._max_connections = (
|
||||
sys.maxsize if max_connections is None else max_connections
|
||||
)
|
||||
self._max_keepalive_connections = (
|
||||
sys.maxsize
|
||||
if max_keepalive_connections is None
|
||||
else max_keepalive_connections
|
||||
)
|
||||
self._max_keepalive_connections = min(
|
||||
self._max_connections, self._max_keepalive_connections
|
||||
)
|
||||
|
||||
self._keepalive_expiry = keepalive_expiry
|
||||
self._http1 = http1
|
||||
self._http2 = http2
|
||||
self._retries = retries
|
||||
self._local_address = local_address
|
||||
self._uds = uds
|
||||
|
||||
self._network_backend = (
|
||||
AutoBackend() if network_backend is None else network_backend
|
||||
)
|
||||
self._socket_options = socket_options
|
||||
|
||||
# The mutable state on a connection pool is the queue of incoming requests,
|
||||
# and the set of connections that are servicing those requests.
|
||||
self._connections: list[AsyncConnectionInterface] = []
|
||||
self._requests: list[AsyncPoolRequest] = []
|
||||
|
||||
# We only mutate the state of the connection pool within an 'optional_thread_lock'
|
||||
# context. This holds a threading lock unless we're running in async mode,
|
||||
# in which case it is a no-op.
|
||||
self._optional_thread_lock = AsyncThreadLock()
|
||||
|
||||
def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
|
||||
if self._proxy is not None:
|
||||
if self._proxy.url.scheme in (b"socks5", b"socks5h"):
|
||||
from .socks_proxy import AsyncSocks5Connection
|
||||
|
||||
return AsyncSocks5Connection(
|
||||
proxy_origin=self._proxy.url.origin,
|
||||
proxy_auth=self._proxy.auth,
|
||||
remote_origin=origin,
|
||||
ssl_context=self._ssl_context,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
http1=self._http1,
|
||||
http2=self._http2,
|
||||
network_backend=self._network_backend,
|
||||
)
|
||||
elif origin.scheme == b"http":
|
||||
from .http_proxy import AsyncForwardHTTPConnection
|
||||
|
||||
return AsyncForwardHTTPConnection(
|
||||
proxy_origin=self._proxy.url.origin,
|
||||
proxy_headers=self._proxy.headers,
|
||||
proxy_ssl_context=self._proxy.ssl_context,
|
||||
remote_origin=origin,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
network_backend=self._network_backend,
|
||||
)
|
||||
from .http_proxy import AsyncTunnelHTTPConnection
|
||||
|
||||
return AsyncTunnelHTTPConnection(
|
||||
proxy_origin=self._proxy.url.origin,
|
||||
proxy_headers=self._proxy.headers,
|
||||
proxy_ssl_context=self._proxy.ssl_context,
|
||||
remote_origin=origin,
|
||||
ssl_context=self._ssl_context,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
http1=self._http1,
|
||||
http2=self._http2,
|
||||
network_backend=self._network_backend,
|
||||
)
|
||||
|
||||
return AsyncHTTPConnection(
|
||||
origin=origin,
|
||||
ssl_context=self._ssl_context,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
http1=self._http1,
|
||||
http2=self._http2,
|
||||
retries=self._retries,
|
||||
local_address=self._local_address,
|
||||
uds=self._uds,
|
||||
network_backend=self._network_backend,
|
||||
socket_options=self._socket_options,
|
||||
)
|
||||
|
||||
@property
|
||||
def connections(self) -> list[AsyncConnectionInterface]:
|
||||
"""
|
||||
Return a list of the connections currently in the pool.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
>>> pool.connections
|
||||
[
|
||||
<AsyncHTTPConnection ['https://example.com:443', HTTP/1.1, ACTIVE, Request Count: 6]>,
|
||||
<AsyncHTTPConnection ['https://example.com:443', HTTP/1.1, IDLE, Request Count: 9]> ,
|
||||
<AsyncHTTPConnection ['http://example.com:80', HTTP/1.1, IDLE, Request Count: 1]>,
|
||||
]
|
||||
```
|
||||
"""
|
||||
return list(self._connections)
|
||||
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
"""
|
||||
Send an HTTP request, and return an HTTP response.
|
||||
|
||||
This is the core implementation that is called into by `.request()` or `.stream()`.
|
||||
"""
|
||||
scheme = request.url.scheme.decode()
|
||||
if scheme == "":
|
||||
raise UnsupportedProtocol(
|
||||
"Request URL is missing an 'http://' or 'https://' protocol."
|
||||
)
|
||||
if scheme not in ("http", "https", "ws", "wss"):
|
||||
raise UnsupportedProtocol(
|
||||
f"Request URL has an unsupported protocol '{scheme}://'."
|
||||
)
|
||||
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("pool", None)
|
||||
|
||||
with self._optional_thread_lock:
|
||||
# Add the incoming request to our request queue.
|
||||
pool_request = AsyncPoolRequest(request)
|
||||
self._requests.append(pool_request)
|
||||
|
||||
try:
|
||||
while True:
|
||||
with self._optional_thread_lock:
|
||||
# Assign incoming requests to available connections,
|
||||
# closing or creating new connections as required.
|
||||
closing = self._assign_requests_to_connections()
|
||||
await self._close_connections(closing)
|
||||
|
||||
# Wait until this request has an assigned connection.
|
||||
connection = await pool_request.wait_for_connection(timeout=timeout)
|
||||
|
||||
try:
|
||||
# Send the request on the assigned connection.
|
||||
response = await connection.handle_async_request(
|
||||
pool_request.request
|
||||
)
|
||||
except ConnectionNotAvailable:
|
||||
# In some cases a connection may initially be available to
|
||||
# handle a request, but then become unavailable.
|
||||
#
|
||||
# In this case we clear the connection and try again.
|
||||
pool_request.clear_connection()
|
||||
else:
|
||||
break # pragma: nocover
|
||||
|
||||
except BaseException as exc:
|
||||
with self._optional_thread_lock:
|
||||
# For any exception or cancellation we remove the request from
|
||||
# the queue, and then re-assign requests to connections.
|
||||
self._requests.remove(pool_request)
|
||||
closing = self._assign_requests_to_connections()
|
||||
|
||||
await self._close_connections(closing)
|
||||
raise exc from None
|
||||
|
||||
# Return the response. Note that in this case we still have to manage
|
||||
# the point at which the response is closed.
|
||||
assert isinstance(response.stream, typing.AsyncIterable)
|
||||
return Response(
|
||||
status=response.status,
|
||||
headers=response.headers,
|
||||
content=PoolByteStream(
|
||||
stream=response.stream, pool_request=pool_request, pool=self
|
||||
),
|
||||
extensions=response.extensions,
|
||||
)
|
||||
|
||||
def _assign_requests_to_connections(self) -> list[AsyncConnectionInterface]:
|
||||
"""
|
||||
Manage the state of the connection pool, assigning incoming
|
||||
requests to connections as available.
|
||||
|
||||
Called whenever a new request is added or removed from the pool.
|
||||
|
||||
Any closing connections are returned, allowing the I/O for closing
|
||||
those connections to be handled seperately.
|
||||
"""
|
||||
closing_connections = []
|
||||
|
||||
# First we handle cleaning up any connections that are closed,
|
||||
# have expired their keep-alive, or surplus idle connections.
|
||||
for connection in list(self._connections):
|
||||
if connection.is_closed():
|
||||
# log: "removing closed connection"
|
||||
self._connections.remove(connection)
|
||||
elif connection.has_expired():
|
||||
# log: "closing expired connection"
|
||||
self._connections.remove(connection)
|
||||
closing_connections.append(connection)
|
||||
elif (
|
||||
connection.is_idle()
|
||||
and len([connection.is_idle() for connection in self._connections])
|
||||
> self._max_keepalive_connections
|
||||
):
|
||||
# log: "closing idle connection"
|
||||
self._connections.remove(connection)
|
||||
closing_connections.append(connection)
|
||||
|
||||
# Assign queued requests to connections.
|
||||
queued_requests = [request for request in self._requests if request.is_queued()]
|
||||
for pool_request in queued_requests:
|
||||
origin = pool_request.request.url.origin
|
||||
available_connections = [
|
||||
connection
|
||||
for connection in self._connections
|
||||
if connection.can_handle_request(origin) and connection.is_available()
|
||||
]
|
||||
idle_connections = [
|
||||
connection for connection in self._connections if connection.is_idle()
|
||||
]
|
||||
|
||||
# There are three cases for how we may be able to handle the request:
|
||||
#
|
||||
# 1. There is an existing connection that can handle the request.
|
||||
# 2. We can create a new connection to handle the request.
|
||||
# 3. We can close an idle connection and then create a new connection
|
||||
# to handle the request.
|
||||
if available_connections:
|
||||
# log: "reusing existing connection"
|
||||
connection = available_connections[0]
|
||||
pool_request.assign_to_connection(connection)
|
||||
elif len(self._connections) < self._max_connections:
|
||||
# log: "creating new connection"
|
||||
connection = self.create_connection(origin)
|
||||
self._connections.append(connection)
|
||||
pool_request.assign_to_connection(connection)
|
||||
elif idle_connections:
|
||||
# log: "closing idle connection"
|
||||
connection = idle_connections[0]
|
||||
self._connections.remove(connection)
|
||||
closing_connections.append(connection)
|
||||
# log: "creating new connection"
|
||||
connection = self.create_connection(origin)
|
||||
self._connections.append(connection)
|
||||
pool_request.assign_to_connection(connection)
|
||||
|
||||
return closing_connections
|
||||
|
||||
async def _close_connections(self, closing: list[AsyncConnectionInterface]) -> None:
|
||||
# Close connections which have been removed from the pool.
|
||||
with AsyncShieldCancellation():
|
||||
for connection in closing:
|
||||
await connection.aclose()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
# Explicitly close the connection pool.
|
||||
# Clears all existing requests and connections.
|
||||
with self._optional_thread_lock:
|
||||
closing_connections = list(self._connections)
|
||||
self._connections = []
|
||||
await self._close_connections(closing_connections)
|
||||
|
||||
async def __aenter__(self) -> AsyncConnectionPool:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None = None,
|
||||
exc_value: BaseException | None = None,
|
||||
traceback: types.TracebackType | None = None,
|
||||
) -> None:
|
||||
await self.aclose()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
with self._optional_thread_lock:
|
||||
request_is_queued = [request.is_queued() for request in self._requests]
|
||||
connection_is_idle = [
|
||||
connection.is_idle() for connection in self._connections
|
||||
]
|
||||
|
||||
num_active_requests = request_is_queued.count(False)
|
||||
num_queued_requests = request_is_queued.count(True)
|
||||
num_active_connections = connection_is_idle.count(False)
|
||||
num_idle_connections = connection_is_idle.count(True)
|
||||
|
||||
requests_info = (
|
||||
f"Requests: {num_active_requests} active, {num_queued_requests} queued"
|
||||
)
|
||||
connection_info = (
|
||||
f"Connections: {num_active_connections} active, {num_idle_connections} idle"
|
||||
)
|
||||
|
||||
return f"<{class_name} [{requests_info} | {connection_info}]>"
|
||||
|
||||
|
||||
class PoolByteStream:
|
||||
def __init__(
|
||||
self,
|
||||
stream: typing.AsyncIterable[bytes],
|
||||
pool_request: AsyncPoolRequest,
|
||||
pool: AsyncConnectionPool,
|
||||
) -> None:
|
||||
self._stream = stream
|
||||
self._pool_request = pool_request
|
||||
self._pool = pool
|
||||
self._closed = False
|
||||
|
||||
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
try:
|
||||
async for part in self._stream:
|
||||
yield part
|
||||
except BaseException as exc:
|
||||
await self.aclose()
|
||||
raise exc from None
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
with AsyncShieldCancellation():
|
||||
if hasattr(self._stream, "aclose"):
|
||||
await self._stream.aclose()
|
||||
|
||||
with self._pool._optional_thread_lock:
|
||||
self._pool._requests.remove(self._pool_request)
|
||||
closing = self._pool._assign_requests_to_connections()
|
||||
|
||||
await self._pool._close_connections(closing)
|
||||
@@ -0,0 +1,379 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import ssl
|
||||
import time
|
||||
import types
|
||||
import typing
|
||||
|
||||
import h11
|
||||
|
||||
from .._backends.base import AsyncNetworkStream
|
||||
from .._exceptions import (
|
||||
ConnectionNotAvailable,
|
||||
LocalProtocolError,
|
||||
RemoteProtocolError,
|
||||
WriteError,
|
||||
map_exceptions,
|
||||
)
|
||||
from .._models import Origin, Request, Response
|
||||
from .._synchronization import AsyncLock, AsyncShieldCancellation
|
||||
from .._trace import Trace
|
||||
from .interfaces import AsyncConnectionInterface
|
||||
|
||||
logger = logging.getLogger("httpcore.http11")
|
||||
|
||||
|
||||
# A subset of `h11.Event` types supported by `_send_event`
|
||||
H11SendEvent = typing.Union[
|
||||
h11.Request,
|
||||
h11.Data,
|
||||
h11.EndOfMessage,
|
||||
]
|
||||
|
||||
|
||||
class HTTPConnectionState(enum.IntEnum):
|
||||
NEW = 0
|
||||
ACTIVE = 1
|
||||
IDLE = 2
|
||||
CLOSED = 3
|
||||
|
||||
|
||||
class AsyncHTTP11Connection(AsyncConnectionInterface):
|
||||
READ_NUM_BYTES = 64 * 1024
|
||||
MAX_INCOMPLETE_EVENT_SIZE = 100 * 1024
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
origin: Origin,
|
||||
stream: AsyncNetworkStream,
|
||||
keepalive_expiry: float | None = None,
|
||||
) -> None:
|
||||
self._origin = origin
|
||||
self._network_stream = stream
|
||||
self._keepalive_expiry: float | None = keepalive_expiry
|
||||
self._expire_at: float | None = None
|
||||
self._state = HTTPConnectionState.NEW
|
||||
self._state_lock = AsyncLock()
|
||||
self._request_count = 0
|
||||
self._h11_state = h11.Connection(
|
||||
our_role=h11.CLIENT,
|
||||
max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE,
|
||||
)
|
||||
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
if not self.can_handle_request(request.url.origin):
|
||||
raise RuntimeError(
|
||||
f"Attempted to send request to {request.url.origin} on connection "
|
||||
f"to {self._origin}"
|
||||
)
|
||||
|
||||
async with self._state_lock:
|
||||
if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE):
|
||||
self._request_count += 1
|
||||
self._state = HTTPConnectionState.ACTIVE
|
||||
self._expire_at = None
|
||||
else:
|
||||
raise ConnectionNotAvailable()
|
||||
|
||||
try:
|
||||
kwargs = {"request": request}
|
||||
try:
|
||||
async with Trace(
|
||||
"send_request_headers", logger, request, kwargs
|
||||
) as trace:
|
||||
await self._send_request_headers(**kwargs)
|
||||
async with Trace("send_request_body", logger, request, kwargs) as trace:
|
||||
await self._send_request_body(**kwargs)
|
||||
except WriteError:
|
||||
# If we get a write error while we're writing the request,
|
||||
# then we supress this error and move on to attempting to
|
||||
# read the response. Servers can sometimes close the request
|
||||
# pre-emptively and then respond with a well formed HTTP
|
||||
# error response.
|
||||
pass
|
||||
|
||||
async with Trace(
|
||||
"receive_response_headers", logger, request, kwargs
|
||||
) as trace:
|
||||
(
|
||||
http_version,
|
||||
status,
|
||||
reason_phrase,
|
||||
headers,
|
||||
trailing_data,
|
||||
) = await self._receive_response_headers(**kwargs)
|
||||
trace.return_value = (
|
||||
http_version,
|
||||
status,
|
||||
reason_phrase,
|
||||
headers,
|
||||
)
|
||||
|
||||
network_stream = self._network_stream
|
||||
|
||||
# CONNECT or Upgrade request
|
||||
if (status == 101) or (
|
||||
(request.method == b"CONNECT") and (200 <= status < 300)
|
||||
):
|
||||
network_stream = AsyncHTTP11UpgradeStream(network_stream, trailing_data)
|
||||
|
||||
return Response(
|
||||
status=status,
|
||||
headers=headers,
|
||||
content=HTTP11ConnectionByteStream(self, request),
|
||||
extensions={
|
||||
"http_version": http_version,
|
||||
"reason_phrase": reason_phrase,
|
||||
"network_stream": network_stream,
|
||||
},
|
||||
)
|
||||
except BaseException as exc:
|
||||
with AsyncShieldCancellation():
|
||||
async with Trace("response_closed", logger, request) as trace:
|
||||
await self._response_closed()
|
||||
raise exc
|
||||
|
||||
# Sending the request...
|
||||
|
||||
async def _send_request_headers(self, request: Request) -> None:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("write", None)
|
||||
|
||||
with map_exceptions({h11.LocalProtocolError: LocalProtocolError}):
|
||||
event = h11.Request(
|
||||
method=request.method,
|
||||
target=request.url.target,
|
||||
headers=request.headers,
|
||||
)
|
||||
await self._send_event(event, timeout=timeout)
|
||||
|
||||
async def _send_request_body(self, request: Request) -> None:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("write", None)
|
||||
|
||||
assert isinstance(request.stream, typing.AsyncIterable)
|
||||
async for chunk in request.stream:
|
||||
event = h11.Data(data=chunk)
|
||||
await self._send_event(event, timeout=timeout)
|
||||
|
||||
await self._send_event(h11.EndOfMessage(), timeout=timeout)
|
||||
|
||||
async def _send_event(self, event: h11.Event, timeout: float | None = None) -> None:
|
||||
bytes_to_send = self._h11_state.send(event)
|
||||
if bytes_to_send is not None:
|
||||
await self._network_stream.write(bytes_to_send, timeout=timeout)
|
||||
|
||||
# Receiving the response...
|
||||
|
||||
async def _receive_response_headers(
|
||||
self, request: Request
|
||||
) -> tuple[bytes, int, bytes, list[tuple[bytes, bytes]], bytes]:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("read", None)
|
||||
|
||||
while True:
|
||||
event = await self._receive_event(timeout=timeout)
|
||||
if isinstance(event, h11.Response):
|
||||
break
|
||||
if (
|
||||
isinstance(event, h11.InformationalResponse)
|
||||
and event.status_code == 101
|
||||
):
|
||||
break
|
||||
|
||||
http_version = b"HTTP/" + event.http_version
|
||||
|
||||
# h11 version 0.11+ supports a `raw_items` interface to get the
|
||||
# raw header casing, rather than the enforced lowercase headers.
|
||||
headers = event.headers.raw_items()
|
||||
|
||||
trailing_data, _ = self._h11_state.trailing_data
|
||||
|
||||
return http_version, event.status_code, event.reason, headers, trailing_data
|
||||
|
||||
async def _receive_response_body(
|
||||
self, request: Request
|
||||
) -> typing.AsyncIterator[bytes]:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("read", None)
|
||||
|
||||
while True:
|
||||
event = await self._receive_event(timeout=timeout)
|
||||
if isinstance(event, h11.Data):
|
||||
yield bytes(event.data)
|
||||
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
|
||||
break
|
||||
|
||||
async def _receive_event(
|
||||
self, timeout: float | None = None
|
||||
) -> h11.Event | type[h11.PAUSED]:
|
||||
while True:
|
||||
with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}):
|
||||
event = self._h11_state.next_event()
|
||||
|
||||
if event is h11.NEED_DATA:
|
||||
data = await self._network_stream.read(
|
||||
self.READ_NUM_BYTES, timeout=timeout
|
||||
)
|
||||
|
||||
# If we feed this case through h11 we'll raise an exception like:
|
||||
#
|
||||
# httpcore.RemoteProtocolError: can't handle event type
|
||||
# ConnectionClosed when role=SERVER and state=SEND_RESPONSE
|
||||
#
|
||||
# Which is accurate, but not very informative from an end-user
|
||||
# perspective. Instead we handle this case distinctly and treat
|
||||
# it as a ConnectError.
|
||||
if data == b"" and self._h11_state.their_state == h11.SEND_RESPONSE:
|
||||
msg = "Server disconnected without sending a response."
|
||||
raise RemoteProtocolError(msg)
|
||||
|
||||
self._h11_state.receive_data(data)
|
||||
else:
|
||||
# mypy fails to narrow the type in the above if statement above
|
||||
return event # type: ignore[return-value]
|
||||
|
||||
async def _response_closed(self) -> None:
|
||||
async with self._state_lock:
|
||||
if (
|
||||
self._h11_state.our_state is h11.DONE
|
||||
and self._h11_state.their_state is h11.DONE
|
||||
):
|
||||
self._state = HTTPConnectionState.IDLE
|
||||
self._h11_state.start_next_cycle()
|
||||
if self._keepalive_expiry is not None:
|
||||
now = time.monotonic()
|
||||
self._expire_at = now + self._keepalive_expiry
|
||||
else:
|
||||
await self.aclose()
|
||||
|
||||
# Once the connection is no longer required...
|
||||
|
||||
async def aclose(self) -> None:
|
||||
# Note that this method unilaterally closes the connection, and does
|
||||
# not have any kind of locking in place around it.
|
||||
self._state = HTTPConnectionState.CLOSED
|
||||
await self._network_stream.aclose()
|
||||
|
||||
# The AsyncConnectionInterface methods provide information about the state of
|
||||
# the connection, allowing for a connection pooling implementation to
|
||||
# determine when to reuse and when to close the connection...
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._origin
|
||||
|
||||
def is_available(self) -> bool:
|
||||
# Note that HTTP/1.1 connections in the "NEW" state are not treated as
|
||||
# being "available". The control flow which created the connection will
|
||||
# be able to send an outgoing request, but the connection will not be
|
||||
# acquired from the connection pool for any other request.
|
||||
return self._state == HTTPConnectionState.IDLE
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
now = time.monotonic()
|
||||
keepalive_expired = self._expire_at is not None and now > self._expire_at
|
||||
|
||||
# If the HTTP connection is idle but the socket is readable, then the
|
||||
# only valid state is that the socket is about to return b"", indicating
|
||||
# a server-initiated disconnect.
|
||||
server_disconnected = (
|
||||
self._state == HTTPConnectionState.IDLE
|
||||
and self._network_stream.get_extra_info("is_readable")
|
||||
)
|
||||
|
||||
return keepalive_expired or server_disconnected
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
return self._state == HTTPConnectionState.IDLE
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self._state == HTTPConnectionState.CLOSED
|
||||
|
||||
def info(self) -> str:
|
||||
origin = str(self._origin)
|
||||
return (
|
||||
f"{origin!r}, HTTP/1.1, {self._state.name}, "
|
||||
f"Request Count: {self._request_count}"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
origin = str(self._origin)
|
||||
return (
|
||||
f"<{class_name} [{origin!r}, {self._state.name}, "
|
||||
f"Request Count: {self._request_count}]>"
|
||||
)
|
||||
|
||||
# These context managers are not used in the standard flow, but are
|
||||
# useful for testing or working with connection instances directly.
|
||||
|
||||
async def __aenter__(self) -> AsyncHTTP11Connection:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None = None,
|
||||
exc_value: BaseException | None = None,
|
||||
traceback: types.TracebackType | None = None,
|
||||
) -> None:
|
||||
await self.aclose()
|
||||
|
||||
|
||||
class HTTP11ConnectionByteStream:
|
||||
def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None:
|
||||
self._connection = connection
|
||||
self._request = request
|
||||
self._closed = False
|
||||
|
||||
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
kwargs = {"request": self._request}
|
||||
try:
|
||||
async with Trace("receive_response_body", logger, self._request, kwargs):
|
||||
async for chunk in self._connection._receive_response_body(**kwargs):
|
||||
yield chunk
|
||||
except BaseException as exc:
|
||||
# If we get an exception while streaming the response,
|
||||
# we want to close the response (and possibly the connection)
|
||||
# before raising that exception.
|
||||
with AsyncShieldCancellation():
|
||||
await self.aclose()
|
||||
raise exc
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
async with Trace("response_closed", logger, self._request):
|
||||
await self._connection._response_closed()
|
||||
|
||||
|
||||
class AsyncHTTP11UpgradeStream(AsyncNetworkStream):
|
||||
def __init__(self, stream: AsyncNetworkStream, leading_data: bytes) -> None:
|
||||
self._stream = stream
|
||||
self._leading_data = leading_data
|
||||
|
||||
async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
|
||||
if self._leading_data:
|
||||
buffer = self._leading_data[:max_bytes]
|
||||
self._leading_data = self._leading_data[max_bytes:]
|
||||
return buffer
|
||||
else:
|
||||
return await self._stream.read(max_bytes, timeout)
|
||||
|
||||
async def write(self, buffer: bytes, timeout: float | None = None) -> None:
|
||||
await self._stream.write(buffer, timeout)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._stream.aclose()
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
ssl_context: ssl.SSLContext,
|
||||
server_hostname: str | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> AsyncNetworkStream:
|
||||
return await self._stream.start_tls(ssl_context, server_hostname, timeout)
|
||||
|
||||
def get_extra_info(self, info: str) -> typing.Any:
|
||||
return self._stream.get_extra_info(info)
|
||||
@@ -0,0 +1,592 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import time
|
||||
import types
|
||||
import typing
|
||||
|
||||
import h2.config
|
||||
import h2.connection
|
||||
import h2.events
|
||||
import h2.exceptions
|
||||
import h2.settings
|
||||
|
||||
from .._backends.base import AsyncNetworkStream
|
||||
from .._exceptions import (
|
||||
ConnectionNotAvailable,
|
||||
LocalProtocolError,
|
||||
RemoteProtocolError,
|
||||
)
|
||||
from .._models import Origin, Request, Response
|
||||
from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation
|
||||
from .._trace import Trace
|
||||
from .interfaces import AsyncConnectionInterface
|
||||
|
||||
logger = logging.getLogger("httpcore.http2")
|
||||
|
||||
|
||||
def has_body_headers(request: Request) -> bool:
|
||||
return any(
|
||||
k.lower() == b"content-length" or k.lower() == b"transfer-encoding"
|
||||
for k, v in request.headers
|
||||
)
|
||||
|
||||
|
||||
class HTTPConnectionState(enum.IntEnum):
|
||||
ACTIVE = 1
|
||||
IDLE = 2
|
||||
CLOSED = 3
|
||||
|
||||
|
||||
class AsyncHTTP2Connection(AsyncConnectionInterface):
|
||||
READ_NUM_BYTES = 64 * 1024
|
||||
CONFIG = h2.config.H2Configuration(validate_inbound_headers=False)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
origin: Origin,
|
||||
stream: AsyncNetworkStream,
|
||||
keepalive_expiry: float | None = None,
|
||||
):
|
||||
self._origin = origin
|
||||
self._network_stream = stream
|
||||
self._keepalive_expiry: float | None = keepalive_expiry
|
||||
self._h2_state = h2.connection.H2Connection(config=self.CONFIG)
|
||||
self._state = HTTPConnectionState.IDLE
|
||||
self._expire_at: float | None = None
|
||||
self._request_count = 0
|
||||
self._init_lock = AsyncLock()
|
||||
self._state_lock = AsyncLock()
|
||||
self._read_lock = AsyncLock()
|
||||
self._write_lock = AsyncLock()
|
||||
self._sent_connection_init = False
|
||||
self._used_all_stream_ids = False
|
||||
self._connection_error = False
|
||||
|
||||
# Mapping from stream ID to response stream events.
|
||||
self._events: dict[
|
||||
int,
|
||||
list[
|
||||
h2.events.ResponseReceived
|
||||
| h2.events.DataReceived
|
||||
| h2.events.StreamEnded
|
||||
| h2.events.StreamReset,
|
||||
],
|
||||
] = {}
|
||||
|
||||
# Connection terminated events are stored as state since
|
||||
# we need to handle them for all streams.
|
||||
self._connection_terminated: h2.events.ConnectionTerminated | None = None
|
||||
|
||||
self._read_exception: Exception | None = None
|
||||
self._write_exception: Exception | None = None
|
||||
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
if not self.can_handle_request(request.url.origin):
|
||||
# This cannot occur in normal operation, since the connection pool
|
||||
# will only send requests on connections that handle them.
|
||||
# It's in place simply for resilience as a guard against incorrect
|
||||
# usage, for anyone working directly with httpcore connections.
|
||||
raise RuntimeError(
|
||||
f"Attempted to send request to {request.url.origin} on connection "
|
||||
f"to {self._origin}"
|
||||
)
|
||||
|
||||
async with self._state_lock:
|
||||
if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE):
|
||||
self._request_count += 1
|
||||
self._expire_at = None
|
||||
self._state = HTTPConnectionState.ACTIVE
|
||||
else:
|
||||
raise ConnectionNotAvailable()
|
||||
|
||||
async with self._init_lock:
|
||||
if not self._sent_connection_init:
|
||||
try:
|
||||
sci_kwargs = {"request": request}
|
||||
async with Trace(
|
||||
"send_connection_init", logger, request, sci_kwargs
|
||||
):
|
||||
await self._send_connection_init(**sci_kwargs)
|
||||
except BaseException as exc:
|
||||
with AsyncShieldCancellation():
|
||||
await self.aclose()
|
||||
raise exc
|
||||
|
||||
self._sent_connection_init = True
|
||||
|
||||
# Initially start with just 1 until the remote server provides
|
||||
# its max_concurrent_streams value
|
||||
self._max_streams = 1
|
||||
|
||||
local_settings_max_streams = (
|
||||
self._h2_state.local_settings.max_concurrent_streams
|
||||
)
|
||||
self._max_streams_semaphore = AsyncSemaphore(local_settings_max_streams)
|
||||
|
||||
for _ in range(local_settings_max_streams - self._max_streams):
|
||||
await self._max_streams_semaphore.acquire()
|
||||
|
||||
await self._max_streams_semaphore.acquire()
|
||||
|
||||
try:
|
||||
stream_id = self._h2_state.get_next_available_stream_id()
|
||||
self._events[stream_id] = []
|
||||
except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover
|
||||
self._used_all_stream_ids = True
|
||||
self._request_count -= 1
|
||||
raise ConnectionNotAvailable()
|
||||
|
||||
try:
|
||||
kwargs = {"request": request, "stream_id": stream_id}
|
||||
async with Trace("send_request_headers", logger, request, kwargs):
|
||||
await self._send_request_headers(request=request, stream_id=stream_id)
|
||||
async with Trace("send_request_body", logger, request, kwargs):
|
||||
await self._send_request_body(request=request, stream_id=stream_id)
|
||||
async with Trace(
|
||||
"receive_response_headers", logger, request, kwargs
|
||||
) as trace:
|
||||
status, headers = await self._receive_response(
|
||||
request=request, stream_id=stream_id
|
||||
)
|
||||
trace.return_value = (status, headers)
|
||||
|
||||
return Response(
|
||||
status=status,
|
||||
headers=headers,
|
||||
content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id),
|
||||
extensions={
|
||||
"http_version": b"HTTP/2",
|
||||
"network_stream": self._network_stream,
|
||||
"stream_id": stream_id,
|
||||
},
|
||||
)
|
||||
except BaseException as exc: # noqa: PIE786
|
||||
with AsyncShieldCancellation():
|
||||
kwargs = {"stream_id": stream_id}
|
||||
async with Trace("response_closed", logger, request, kwargs):
|
||||
await self._response_closed(stream_id=stream_id)
|
||||
|
||||
if isinstance(exc, h2.exceptions.ProtocolError):
|
||||
# One case where h2 can raise a protocol error is when a
|
||||
# closed frame has been seen by the state machine.
|
||||
#
|
||||
# This happens when one stream is reading, and encounters
|
||||
# a GOAWAY event. Other flows of control may then raise
|
||||
# a protocol error at any point they interact with the 'h2_state'.
|
||||
#
|
||||
# In this case we'll have stored the event, and should raise
|
||||
# it as a RemoteProtocolError.
|
||||
if self._connection_terminated: # pragma: nocover
|
||||
raise RemoteProtocolError(self._connection_terminated)
|
||||
# If h2 raises a protocol error in some other state then we
|
||||
# must somehow have made a protocol violation.
|
||||
raise LocalProtocolError(exc) # pragma: nocover
|
||||
|
||||
raise exc
|
||||
|
||||
async def _send_connection_init(self, request: Request) -> None:
|
||||
"""
|
||||
The HTTP/2 connection requires some initial setup before we can start
|
||||
using individual request/response streams on it.
|
||||
"""
|
||||
# Need to set these manually here instead of manipulating via
|
||||
# __setitem__() otherwise the H2Connection will emit SettingsUpdate
|
||||
# frames in addition to sending the undesired defaults.
|
||||
self._h2_state.local_settings = h2.settings.Settings(
|
||||
client=True,
|
||||
initial_values={
|
||||
# Disable PUSH_PROMISE frames from the server since we don't do anything
|
||||
# with them for now. Maybe when we support caching?
|
||||
h2.settings.SettingCodes.ENABLE_PUSH: 0,
|
||||
# These two are taken from h2 for safe defaults
|
||||
h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100,
|
||||
h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536,
|
||||
},
|
||||
)
|
||||
|
||||
# Some websites (*cough* Yahoo *cough*) balk at this setting being
|
||||
# present in the initial handshake since it's not defined in the original
|
||||
# RFC despite the RFC mandating ignoring settings you don't know about.
|
||||
del self._h2_state.local_settings[
|
||||
h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL
|
||||
]
|
||||
|
||||
self._h2_state.initiate_connection()
|
||||
self._h2_state.increment_flow_control_window(2**24)
|
||||
await self._write_outgoing_data(request)
|
||||
|
||||
# Sending the request...
|
||||
|
||||
async def _send_request_headers(self, request: Request, stream_id: int) -> None:
|
||||
"""
|
||||
Send the request headers to a given stream ID.
|
||||
"""
|
||||
end_stream = not has_body_headers(request)
|
||||
|
||||
# In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'.
|
||||
# In order to gracefully handle HTTP/1.1 and HTTP/2 we always require
|
||||
# HTTP/1.1 style headers, and map them appropriately if we end up on
|
||||
# an HTTP/2 connection.
|
||||
authority = [v for k, v in request.headers if k.lower() == b"host"][0]
|
||||
|
||||
headers = [
|
||||
(b":method", request.method),
|
||||
(b":authority", authority),
|
||||
(b":scheme", request.url.scheme),
|
||||
(b":path", request.url.target),
|
||||
] + [
|
||||
(k.lower(), v)
|
||||
for k, v in request.headers
|
||||
if k.lower()
|
||||
not in (
|
||||
b"host",
|
||||
b"transfer-encoding",
|
||||
)
|
||||
]
|
||||
|
||||
self._h2_state.send_headers(stream_id, headers, end_stream=end_stream)
|
||||
self._h2_state.increment_flow_control_window(2**24, stream_id=stream_id)
|
||||
await self._write_outgoing_data(request)
|
||||
|
||||
async def _send_request_body(self, request: Request, stream_id: int) -> None:
|
||||
"""
|
||||
Iterate over the request body sending it to a given stream ID.
|
||||
"""
|
||||
if not has_body_headers(request):
|
||||
return
|
||||
|
||||
assert isinstance(request.stream, typing.AsyncIterable)
|
||||
async for data in request.stream:
|
||||
await self._send_stream_data(request, stream_id, data)
|
||||
await self._send_end_stream(request, stream_id)
|
||||
|
||||
async def _send_stream_data(
|
||||
self, request: Request, stream_id: int, data: bytes
|
||||
) -> None:
|
||||
"""
|
||||
Send a single chunk of data in one or more data frames.
|
||||
"""
|
||||
while data:
|
||||
max_flow = await self._wait_for_outgoing_flow(request, stream_id)
|
||||
chunk_size = min(len(data), max_flow)
|
||||
chunk, data = data[:chunk_size], data[chunk_size:]
|
||||
self._h2_state.send_data(stream_id, chunk)
|
||||
await self._write_outgoing_data(request)
|
||||
|
||||
async def _send_end_stream(self, request: Request, stream_id: int) -> None:
|
||||
"""
|
||||
Send an empty data frame on on a given stream ID with the END_STREAM flag set.
|
||||
"""
|
||||
self._h2_state.end_stream(stream_id)
|
||||
await self._write_outgoing_data(request)
|
||||
|
||||
# Receiving the response...
|
||||
|
||||
async def _receive_response(
|
||||
self, request: Request, stream_id: int
|
||||
) -> tuple[int, list[tuple[bytes, bytes]]]:
|
||||
"""
|
||||
Return the response status code and headers for a given stream ID.
|
||||
"""
|
||||
while True:
|
||||
event = await self._receive_stream_event(request, stream_id)
|
||||
if isinstance(event, h2.events.ResponseReceived):
|
||||
break
|
||||
|
||||
status_code = 200
|
||||
headers = []
|
||||
assert event.headers is not None
|
||||
for k, v in event.headers:
|
||||
if k == b":status":
|
||||
status_code = int(v.decode("ascii", errors="ignore"))
|
||||
elif not k.startswith(b":"):
|
||||
headers.append((k, v))
|
||||
|
||||
return (status_code, headers)
|
||||
|
||||
async def _receive_response_body(
|
||||
self, request: Request, stream_id: int
|
||||
) -> typing.AsyncIterator[bytes]:
|
||||
"""
|
||||
Iterator that returns the bytes of the response body for a given stream ID.
|
||||
"""
|
||||
while True:
|
||||
event = await self._receive_stream_event(request, stream_id)
|
||||
if isinstance(event, h2.events.DataReceived):
|
||||
assert event.flow_controlled_length is not None
|
||||
assert event.data is not None
|
||||
amount = event.flow_controlled_length
|
||||
self._h2_state.acknowledge_received_data(amount, stream_id)
|
||||
await self._write_outgoing_data(request)
|
||||
yield event.data
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
break
|
||||
|
||||
async def _receive_stream_event(
|
||||
self, request: Request, stream_id: int
|
||||
) -> h2.events.ResponseReceived | h2.events.DataReceived | h2.events.StreamEnded:
|
||||
"""
|
||||
Return the next available event for a given stream ID.
|
||||
|
||||
Will read more data from the network if required.
|
||||
"""
|
||||
while not self._events.get(stream_id):
|
||||
await self._receive_events(request, stream_id)
|
||||
event = self._events[stream_id].pop(0)
|
||||
if isinstance(event, h2.events.StreamReset):
|
||||
raise RemoteProtocolError(event)
|
||||
return event
|
||||
|
||||
async def _receive_events(
|
||||
self, request: Request, stream_id: int | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Read some data from the network until we see one or more events
|
||||
for a given stream ID.
|
||||
"""
|
||||
async with self._read_lock:
|
||||
if self._connection_terminated is not None:
|
||||
last_stream_id = self._connection_terminated.last_stream_id
|
||||
if stream_id and last_stream_id and stream_id > last_stream_id:
|
||||
self._request_count -= 1
|
||||
raise ConnectionNotAvailable()
|
||||
raise RemoteProtocolError(self._connection_terminated)
|
||||
|
||||
# This conditional is a bit icky. We don't want to block reading if we've
|
||||
# actually got an event to return for a given stream. We need to do that
|
||||
# check *within* the atomic read lock. Though it also need to be optional,
|
||||
# because when we call it from `_wait_for_outgoing_flow` we *do* want to
|
||||
# block until we've available flow control, event when we have events
|
||||
# pending for the stream ID we're attempting to send on.
|
||||
if stream_id is None or not self._events.get(stream_id):
|
||||
events = await self._read_incoming_data(request)
|
||||
for event in events:
|
||||
if isinstance(event, h2.events.RemoteSettingsChanged):
|
||||
async with Trace(
|
||||
"receive_remote_settings", logger, request
|
||||
) as trace:
|
||||
await self._receive_remote_settings_change(event)
|
||||
trace.return_value = event
|
||||
|
||||
elif isinstance(
|
||||
event,
|
||||
(
|
||||
h2.events.ResponseReceived,
|
||||
h2.events.DataReceived,
|
||||
h2.events.StreamEnded,
|
||||
h2.events.StreamReset,
|
||||
),
|
||||
):
|
||||
if event.stream_id in self._events:
|
||||
self._events[event.stream_id].append(event)
|
||||
|
||||
elif isinstance(event, h2.events.ConnectionTerminated):
|
||||
self._connection_terminated = event
|
||||
|
||||
await self._write_outgoing_data(request)
|
||||
|
||||
async def _receive_remote_settings_change(
|
||||
self, event: h2.events.RemoteSettingsChanged
|
||||
) -> None:
|
||||
max_concurrent_streams = event.changed_settings.get(
|
||||
h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS
|
||||
)
|
||||
if max_concurrent_streams:
|
||||
new_max_streams = min(
|
||||
max_concurrent_streams.new_value,
|
||||
self._h2_state.local_settings.max_concurrent_streams,
|
||||
)
|
||||
if new_max_streams and new_max_streams != self._max_streams:
|
||||
while new_max_streams > self._max_streams:
|
||||
await self._max_streams_semaphore.release()
|
||||
self._max_streams += 1
|
||||
while new_max_streams < self._max_streams:
|
||||
await self._max_streams_semaphore.acquire()
|
||||
self._max_streams -= 1
|
||||
|
||||
async def _response_closed(self, stream_id: int) -> None:
|
||||
await self._max_streams_semaphore.release()
|
||||
del self._events[stream_id]
|
||||
async with self._state_lock:
|
||||
if self._connection_terminated and not self._events:
|
||||
await self.aclose()
|
||||
|
||||
elif self._state == HTTPConnectionState.ACTIVE and not self._events:
|
||||
self._state = HTTPConnectionState.IDLE
|
||||
if self._keepalive_expiry is not None:
|
||||
now = time.monotonic()
|
||||
self._expire_at = now + self._keepalive_expiry
|
||||
if self._used_all_stream_ids: # pragma: nocover
|
||||
await self.aclose()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
# Note that this method unilaterally closes the connection, and does
|
||||
# not have any kind of locking in place around it.
|
||||
self._h2_state.close_connection()
|
||||
self._state = HTTPConnectionState.CLOSED
|
||||
await self._network_stream.aclose()
|
||||
|
||||
# Wrappers around network read/write operations...
|
||||
|
||||
async def _read_incoming_data(self, request: Request) -> list[h2.events.Event]:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("read", None)
|
||||
|
||||
if self._read_exception is not None:
|
||||
raise self._read_exception # pragma: nocover
|
||||
|
||||
try:
|
||||
data = await self._network_stream.read(self.READ_NUM_BYTES, timeout)
|
||||
if data == b"":
|
||||
raise RemoteProtocolError("Server disconnected")
|
||||
except Exception as exc:
|
||||
# If we get a network error we should:
|
||||
#
|
||||
# 1. Save the exception and just raise it immediately on any future reads.
|
||||
# (For example, this means that a single read timeout or disconnect will
|
||||
# immediately close all pending streams. Without requiring multiple
|
||||
# sequential timeouts.)
|
||||
# 2. Mark the connection as errored, so that we don't accept any other
|
||||
# incoming requests.
|
||||
self._read_exception = exc
|
||||
self._connection_error = True
|
||||
raise exc
|
||||
|
||||
events: list[h2.events.Event] = self._h2_state.receive_data(data)
|
||||
|
||||
return events
|
||||
|
||||
async def _write_outgoing_data(self, request: Request) -> None:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("write", None)
|
||||
|
||||
async with self._write_lock:
|
||||
data_to_send = self._h2_state.data_to_send()
|
||||
|
||||
if self._write_exception is not None:
|
||||
raise self._write_exception # pragma: nocover
|
||||
|
||||
try:
|
||||
await self._network_stream.write(data_to_send, timeout)
|
||||
except Exception as exc: # pragma: nocover
|
||||
# If we get a network error we should:
|
||||
#
|
||||
# 1. Save the exception and just raise it immediately on any future write.
|
||||
# (For example, this means that a single write timeout or disconnect will
|
||||
# immediately close all pending streams. Without requiring multiple
|
||||
# sequential timeouts.)
|
||||
# 2. Mark the connection as errored, so that we don't accept any other
|
||||
# incoming requests.
|
||||
self._write_exception = exc
|
||||
self._connection_error = True
|
||||
raise exc
|
||||
|
||||
# Flow control...
|
||||
|
||||
async def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int:
|
||||
"""
|
||||
Returns the maximum allowable outgoing flow for a given stream.
|
||||
|
||||
If the allowable flow is zero, then waits on the network until
|
||||
WindowUpdated frames have increased the flow rate.
|
||||
https://tools.ietf.org/html/rfc7540#section-6.9
|
||||
"""
|
||||
local_flow: int = self._h2_state.local_flow_control_window(stream_id)
|
||||
max_frame_size: int = self._h2_state.max_outbound_frame_size
|
||||
flow = min(local_flow, max_frame_size)
|
||||
while flow == 0:
|
||||
await self._receive_events(request)
|
||||
local_flow = self._h2_state.local_flow_control_window(stream_id)
|
||||
max_frame_size = self._h2_state.max_outbound_frame_size
|
||||
flow = min(local_flow, max_frame_size)
|
||||
return flow
|
||||
|
||||
# Interface for connection pooling...
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._origin
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return (
|
||||
self._state != HTTPConnectionState.CLOSED
|
||||
and not self._connection_error
|
||||
and not self._used_all_stream_ids
|
||||
and not (
|
||||
self._h2_state.state_machine.state
|
||||
== h2.connection.ConnectionState.CLOSED
|
||||
)
|
||||
)
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
now = time.monotonic()
|
||||
return self._expire_at is not None and now > self._expire_at
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
return self._state == HTTPConnectionState.IDLE
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self._state == HTTPConnectionState.CLOSED
|
||||
|
||||
def info(self) -> str:
|
||||
origin = str(self._origin)
|
||||
return (
|
||||
f"{origin!r}, HTTP/2, {self._state.name}, "
|
||||
f"Request Count: {self._request_count}"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
origin = str(self._origin)
|
||||
return (
|
||||
f"<{class_name} [{origin!r}, {self._state.name}, "
|
||||
f"Request Count: {self._request_count}]>"
|
||||
)
|
||||
|
||||
# These context managers are not used in the standard flow, but are
|
||||
# useful for testing or working with connection instances directly.
|
||||
|
||||
async def __aenter__(self) -> AsyncHTTP2Connection:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None = None,
|
||||
exc_value: BaseException | None = None,
|
||||
traceback: types.TracebackType | None = None,
|
||||
) -> None:
|
||||
await self.aclose()
|
||||
|
||||
|
||||
class HTTP2ConnectionByteStream:
|
||||
def __init__(
|
||||
self, connection: AsyncHTTP2Connection, request: Request, stream_id: int
|
||||
) -> None:
|
||||
self._connection = connection
|
||||
self._request = request
|
||||
self._stream_id = stream_id
|
||||
self._closed = False
|
||||
|
||||
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
kwargs = {"request": self._request, "stream_id": self._stream_id}
|
||||
try:
|
||||
async with Trace("receive_response_body", logger, self._request, kwargs):
|
||||
async for chunk in self._connection._receive_response_body(
|
||||
request=self._request, stream_id=self._stream_id
|
||||
):
|
||||
yield chunk
|
||||
except BaseException as exc:
|
||||
# If we get an exception while streaming the response,
|
||||
# we want to close the response (and possibly the connection)
|
||||
# before raising that exception.
|
||||
with AsyncShieldCancellation():
|
||||
await self.aclose()
|
||||
raise exc
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
kwargs = {"stream_id": self._stream_id}
|
||||
async with Trace("response_closed", logger, self._request, kwargs):
|
||||
await self._connection._response_closed(stream_id=self._stream_id)
|
||||
@@ -0,0 +1,367 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import ssl
|
||||
import typing
|
||||
|
||||
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
|
||||
from .._exceptions import ProxyError
|
||||
from .._models import (
|
||||
URL,
|
||||
Origin,
|
||||
Request,
|
||||
Response,
|
||||
enforce_bytes,
|
||||
enforce_headers,
|
||||
enforce_url,
|
||||
)
|
||||
from .._ssl import default_ssl_context
|
||||
from .._synchronization import AsyncLock
|
||||
from .._trace import Trace
|
||||
from .connection import AsyncHTTPConnection
|
||||
from .connection_pool import AsyncConnectionPool
|
||||
from .http11 import AsyncHTTP11Connection
|
||||
from .interfaces import AsyncConnectionInterface
|
||||
|
||||
ByteOrStr = typing.Union[bytes, str]
|
||||
HeadersAsSequence = typing.Sequence[typing.Tuple[ByteOrStr, ByteOrStr]]
|
||||
HeadersAsMapping = typing.Mapping[ByteOrStr, ByteOrStr]
|
||||
|
||||
|
||||
logger = logging.getLogger("httpcore.proxy")
|
||||
|
||||
|
||||
def merge_headers(
|
||||
default_headers: typing.Sequence[tuple[bytes, bytes]] | None = None,
|
||||
override_headers: typing.Sequence[tuple[bytes, bytes]] | None = None,
|
||||
) -> list[tuple[bytes, bytes]]:
|
||||
"""
|
||||
Append default_headers and override_headers, de-duplicating if a key exists
|
||||
in both cases.
|
||||
"""
|
||||
default_headers = [] if default_headers is None else list(default_headers)
|
||||
override_headers = [] if override_headers is None else list(override_headers)
|
||||
has_override = set(key.lower() for key, value in override_headers)
|
||||
default_headers = [
|
||||
(key, value)
|
||||
for key, value in default_headers
|
||||
if key.lower() not in has_override
|
||||
]
|
||||
return default_headers + override_headers
|
||||
|
||||
|
||||
class AsyncHTTPProxy(AsyncConnectionPool): # pragma: nocover
|
||||
"""
|
||||
A connection pool that sends requests via an HTTP proxy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proxy_url: URL | bytes | str,
|
||||
proxy_auth: tuple[bytes | str, bytes | str] | None = None,
|
||||
proxy_headers: HeadersAsMapping | HeadersAsSequence | None = None,
|
||||
ssl_context: ssl.SSLContext | None = None,
|
||||
proxy_ssl_context: ssl.SSLContext | None = None,
|
||||
max_connections: int | None = 10,
|
||||
max_keepalive_connections: int | None = None,
|
||||
keepalive_expiry: float | None = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
retries: int = 0,
|
||||
local_address: str | None = None,
|
||||
uds: str | None = None,
|
||||
network_backend: AsyncNetworkBackend | None = None,
|
||||
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
A connection pool for making HTTP requests.
|
||||
|
||||
Parameters:
|
||||
proxy_url: The URL to use when connecting to the proxy server.
|
||||
For example `"http://127.0.0.1:8080/"`.
|
||||
proxy_auth: Any proxy authentication as a two-tuple of
|
||||
(username, password). May be either bytes or ascii-only str.
|
||||
proxy_headers: Any HTTP headers to use for the proxy requests.
|
||||
For example `{"Proxy-Authorization": "Basic <username>:<password>"}`.
|
||||
ssl_context: An SSL context to use for verifying connections.
|
||||
If not specified, the default `httpcore.default_ssl_context()`
|
||||
will be used.
|
||||
proxy_ssl_context: The same as `ssl_context`, but for a proxy server rather than a remote origin.
|
||||
max_connections: The maximum number of concurrent HTTP connections that
|
||||
the pool should allow. Any attempt to send a request on a pool that
|
||||
would exceed this amount will block until a connection is available.
|
||||
max_keepalive_connections: The maximum number of idle HTTP connections
|
||||
that will be maintained in the pool.
|
||||
keepalive_expiry: The duration in seconds that an idle HTTP connection
|
||||
may be maintained for before being expired from the pool.
|
||||
http1: A boolean indicating if HTTP/1.1 requests should be supported
|
||||
by the connection pool. Defaults to True.
|
||||
http2: A boolean indicating if HTTP/2 requests should be supported by
|
||||
the connection pool. Defaults to False.
|
||||
retries: The maximum number of retries when trying to establish
|
||||
a connection.
|
||||
local_address: Local address to connect from. Can also be used to
|
||||
connect using a particular address family. Using
|
||||
`local_address="0.0.0.0"` will connect using an `AF_INET` address
|
||||
(IPv4), while using `local_address="::"` will connect using an
|
||||
`AF_INET6` address (IPv6).
|
||||
uds: Path to a Unix Domain Socket to use instead of TCP sockets.
|
||||
network_backend: A backend instance to use for handling network I/O.
|
||||
"""
|
||||
super().__init__(
|
||||
ssl_context=ssl_context,
|
||||
max_connections=max_connections,
|
||||
max_keepalive_connections=max_keepalive_connections,
|
||||
keepalive_expiry=keepalive_expiry,
|
||||
http1=http1,
|
||||
http2=http2,
|
||||
network_backend=network_backend,
|
||||
retries=retries,
|
||||
local_address=local_address,
|
||||
uds=uds,
|
||||
socket_options=socket_options,
|
||||
)
|
||||
|
||||
self._proxy_url = enforce_url(proxy_url, name="proxy_url")
|
||||
if (
|
||||
self._proxy_url.scheme == b"http" and proxy_ssl_context is not None
|
||||
): # pragma: no cover
|
||||
raise RuntimeError(
|
||||
"The `proxy_ssl_context` argument is not allowed for the http scheme"
|
||||
)
|
||||
|
||||
self._ssl_context = ssl_context
|
||||
self._proxy_ssl_context = proxy_ssl_context
|
||||
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
|
||||
if proxy_auth is not None:
|
||||
username = enforce_bytes(proxy_auth[0], name="proxy_auth")
|
||||
password = enforce_bytes(proxy_auth[1], name="proxy_auth")
|
||||
userpass = username + b":" + password
|
||||
authorization = b"Basic " + base64.b64encode(userpass)
|
||||
self._proxy_headers = [
|
||||
(b"Proxy-Authorization", authorization)
|
||||
] + self._proxy_headers
|
||||
|
||||
def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
|
||||
if origin.scheme == b"http":
|
||||
return AsyncForwardHTTPConnection(
|
||||
proxy_origin=self._proxy_url.origin,
|
||||
proxy_headers=self._proxy_headers,
|
||||
remote_origin=origin,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
network_backend=self._network_backend,
|
||||
proxy_ssl_context=self._proxy_ssl_context,
|
||||
)
|
||||
return AsyncTunnelHTTPConnection(
|
||||
proxy_origin=self._proxy_url.origin,
|
||||
proxy_headers=self._proxy_headers,
|
||||
remote_origin=origin,
|
||||
ssl_context=self._ssl_context,
|
||||
proxy_ssl_context=self._proxy_ssl_context,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
http1=self._http1,
|
||||
http2=self._http2,
|
||||
network_backend=self._network_backend,
|
||||
)
|
||||
|
||||
|
||||
class AsyncForwardHTTPConnection(AsyncConnectionInterface):
|
||||
def __init__(
|
||||
self,
|
||||
proxy_origin: Origin,
|
||||
remote_origin: Origin,
|
||||
proxy_headers: HeadersAsMapping | HeadersAsSequence | None = None,
|
||||
keepalive_expiry: float | None = None,
|
||||
network_backend: AsyncNetworkBackend | None = None,
|
||||
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
|
||||
proxy_ssl_context: ssl.SSLContext | None = None,
|
||||
) -> None:
|
||||
self._connection = AsyncHTTPConnection(
|
||||
origin=proxy_origin,
|
||||
keepalive_expiry=keepalive_expiry,
|
||||
network_backend=network_backend,
|
||||
socket_options=socket_options,
|
||||
ssl_context=proxy_ssl_context,
|
||||
)
|
||||
self._proxy_origin = proxy_origin
|
||||
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
|
||||
self._remote_origin = remote_origin
|
||||
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
headers = merge_headers(self._proxy_headers, request.headers)
|
||||
url = URL(
|
||||
scheme=self._proxy_origin.scheme,
|
||||
host=self._proxy_origin.host,
|
||||
port=self._proxy_origin.port,
|
||||
target=bytes(request.url),
|
||||
)
|
||||
proxy_request = Request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=request.stream,
|
||||
extensions=request.extensions,
|
||||
)
|
||||
return await self._connection.handle_async_request(proxy_request)
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._remote_origin
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._connection.aclose()
|
||||
|
||||
def info(self) -> str:
|
||||
return self._connection.info()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self._connection.is_available()
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
return self._connection.has_expired()
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
return self._connection.is_idle()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self._connection.is_closed()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} [{self.info()}]>"
|
||||
|
||||
|
||||
class AsyncTunnelHTTPConnection(AsyncConnectionInterface):
|
||||
def __init__(
|
||||
self,
|
||||
proxy_origin: Origin,
|
||||
remote_origin: Origin,
|
||||
ssl_context: ssl.SSLContext | None = None,
|
||||
proxy_ssl_context: ssl.SSLContext | None = None,
|
||||
proxy_headers: typing.Sequence[tuple[bytes, bytes]] | None = None,
|
||||
keepalive_expiry: float | None = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
network_backend: AsyncNetworkBackend | None = None,
|
||||
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
|
||||
) -> None:
|
||||
self._connection: AsyncConnectionInterface = AsyncHTTPConnection(
|
||||
origin=proxy_origin,
|
||||
keepalive_expiry=keepalive_expiry,
|
||||
network_backend=network_backend,
|
||||
socket_options=socket_options,
|
||||
ssl_context=proxy_ssl_context,
|
||||
)
|
||||
self._proxy_origin = proxy_origin
|
||||
self._remote_origin = remote_origin
|
||||
self._ssl_context = ssl_context
|
||||
self._proxy_ssl_context = proxy_ssl_context
|
||||
self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers")
|
||||
self._keepalive_expiry = keepalive_expiry
|
||||
self._http1 = http1
|
||||
self._http2 = http2
|
||||
self._connect_lock = AsyncLock()
|
||||
self._connected = False
|
||||
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
timeout = timeouts.get("connect", None)
|
||||
|
||||
async with self._connect_lock:
|
||||
if not self._connected:
|
||||
target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port)
|
||||
|
||||
connect_url = URL(
|
||||
scheme=self._proxy_origin.scheme,
|
||||
host=self._proxy_origin.host,
|
||||
port=self._proxy_origin.port,
|
||||
target=target,
|
||||
)
|
||||
connect_headers = merge_headers(
|
||||
[(b"Host", target), (b"Accept", b"*/*")], self._proxy_headers
|
||||
)
|
||||
connect_request = Request(
|
||||
method=b"CONNECT",
|
||||
url=connect_url,
|
||||
headers=connect_headers,
|
||||
extensions=request.extensions,
|
||||
)
|
||||
connect_response = await self._connection.handle_async_request(
|
||||
connect_request
|
||||
)
|
||||
|
||||
if connect_response.status < 200 or connect_response.status > 299:
|
||||
reason_bytes = connect_response.extensions.get("reason_phrase", b"")
|
||||
reason_str = reason_bytes.decode("ascii", errors="ignore")
|
||||
msg = "%d %s" % (connect_response.status, reason_str)
|
||||
await self._connection.aclose()
|
||||
raise ProxyError(msg)
|
||||
|
||||
stream = connect_response.extensions["network_stream"]
|
||||
|
||||
# Upgrade the stream to SSL
|
||||
ssl_context = (
|
||||
default_ssl_context()
|
||||
if self._ssl_context is None
|
||||
else self._ssl_context
|
||||
)
|
||||
alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"]
|
||||
ssl_context.set_alpn_protocols(alpn_protocols)
|
||||
|
||||
kwargs = {
|
||||
"ssl_context": ssl_context,
|
||||
"server_hostname": self._remote_origin.host.decode("ascii"),
|
||||
"timeout": timeout,
|
||||
}
|
||||
async with Trace("start_tls", logger, request, kwargs) as trace:
|
||||
stream = await stream.start_tls(**kwargs)
|
||||
trace.return_value = stream
|
||||
|
||||
# Determine if we should be using HTTP/1.1 or HTTP/2
|
||||
ssl_object = stream.get_extra_info("ssl_object")
|
||||
http2_negotiated = (
|
||||
ssl_object is not None
|
||||
and ssl_object.selected_alpn_protocol() == "h2"
|
||||
)
|
||||
|
||||
# Create the HTTP/1.1 or HTTP/2 connection
|
||||
if http2_negotiated or (self._http2 and not self._http1):
|
||||
from .http2 import AsyncHTTP2Connection
|
||||
|
||||
self._connection = AsyncHTTP2Connection(
|
||||
origin=self._remote_origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
else:
|
||||
self._connection = AsyncHTTP11Connection(
|
||||
origin=self._remote_origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
|
||||
self._connected = True
|
||||
return await self._connection.handle_async_request(request)
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._remote_origin
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._connection.aclose()
|
||||
|
||||
def info(self) -> str:
|
||||
return self._connection.info()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self._connection.is_available()
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
return self._connection.has_expired()
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
return self._connection.is_idle()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self._connection.is_closed()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} [{self.info()}]>"
|
||||
@@ -0,0 +1,137 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import typing
|
||||
|
||||
from .._models import (
|
||||
URL,
|
||||
Extensions,
|
||||
HeaderTypes,
|
||||
Origin,
|
||||
Request,
|
||||
Response,
|
||||
enforce_bytes,
|
||||
enforce_headers,
|
||||
enforce_url,
|
||||
include_request_headers,
|
||||
)
|
||||
|
||||
|
||||
class AsyncRequestInterface:
|
||||
async def request(
|
||||
self,
|
||||
method: bytes | str,
|
||||
url: URL | bytes | str,
|
||||
*,
|
||||
headers: HeaderTypes = None,
|
||||
content: bytes | typing.AsyncIterator[bytes] | None = None,
|
||||
extensions: Extensions | None = None,
|
||||
) -> Response:
|
||||
# Strict type checking on our parameters.
|
||||
method = enforce_bytes(method, name="method")
|
||||
url = enforce_url(url, name="url")
|
||||
headers = enforce_headers(headers, name="headers")
|
||||
|
||||
# Include Host header, and optionally Content-Length or Transfer-Encoding.
|
||||
headers = include_request_headers(headers, url=url, content=content)
|
||||
|
||||
request = Request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=content,
|
||||
extensions=extensions,
|
||||
)
|
||||
response = await self.handle_async_request(request)
|
||||
try:
|
||||
await response.aread()
|
||||
finally:
|
||||
await response.aclose()
|
||||
return response
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def stream(
|
||||
self,
|
||||
method: bytes | str,
|
||||
url: URL | bytes | str,
|
||||
*,
|
||||
headers: HeaderTypes = None,
|
||||
content: bytes | typing.AsyncIterator[bytes] | None = None,
|
||||
extensions: Extensions | None = None,
|
||||
) -> typing.AsyncIterator[Response]:
|
||||
# Strict type checking on our parameters.
|
||||
method = enforce_bytes(method, name="method")
|
||||
url = enforce_url(url, name="url")
|
||||
headers = enforce_headers(headers, name="headers")
|
||||
|
||||
# Include Host header, and optionally Content-Length or Transfer-Encoding.
|
||||
headers = include_request_headers(headers, url=url, content=content)
|
||||
|
||||
request = Request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=content,
|
||||
extensions=extensions,
|
||||
)
|
||||
response = await self.handle_async_request(request)
|
||||
try:
|
||||
yield response
|
||||
finally:
|
||||
await response.aclose()
|
||||
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
|
||||
class AsyncConnectionInterface(AsyncRequestInterface):
|
||||
async def aclose(self) -> None:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def info(self) -> str:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""
|
||||
Return `True` if the connection is currently able to accept an
|
||||
outgoing request.
|
||||
|
||||
An HTTP/1.1 connection will only be available if it is currently idle.
|
||||
|
||||
An HTTP/2 connection will be available so long as the stream ID space is
|
||||
not yet exhausted, and the connection is not in an error state.
|
||||
|
||||
While the connection is being established we may not yet know if it is going
|
||||
to result in an HTTP/1.1 or HTTP/2 connection. The connection should be
|
||||
treated as being available, but might ultimately raise `NewConnectionRequired`
|
||||
required exceptions if multiple requests are attempted over a connection
|
||||
that ends up being established as HTTP/1.1.
|
||||
"""
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
"""
|
||||
Return `True` if the connection is in a state where it should be closed.
|
||||
|
||||
This either means that the connection is idle and it has passed the
|
||||
expiry time on its keep-alive, or that server has sent an EOF.
|
||||
"""
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
"""
|
||||
Return `True` if the connection is currently idle.
|
||||
"""
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""
|
||||
Return `True` if the connection has been closed.
|
||||
|
||||
Used when a response is closed to determine if the connection may be
|
||||
returned to the connection pool or not.
|
||||
"""
|
||||
raise NotImplementedError() # pragma: nocover
|
||||
@@ -0,0 +1,341 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import ssl
|
||||
|
||||
import socksio
|
||||
|
||||
from .._backends.auto import AutoBackend
|
||||
from .._backends.base import AsyncNetworkBackend, AsyncNetworkStream
|
||||
from .._exceptions import ConnectionNotAvailable, ProxyError
|
||||
from .._models import URL, Origin, Request, Response, enforce_bytes, enforce_url
|
||||
from .._ssl import default_ssl_context
|
||||
from .._synchronization import AsyncLock
|
||||
from .._trace import Trace
|
||||
from .connection_pool import AsyncConnectionPool
|
||||
from .http11 import AsyncHTTP11Connection
|
||||
from .interfaces import AsyncConnectionInterface
|
||||
|
||||
logger = logging.getLogger("httpcore.socks")
|
||||
|
||||
|
||||
AUTH_METHODS = {
|
||||
b"\x00": "NO AUTHENTICATION REQUIRED",
|
||||
b"\x01": "GSSAPI",
|
||||
b"\x02": "USERNAME/PASSWORD",
|
||||
b"\xff": "NO ACCEPTABLE METHODS",
|
||||
}
|
||||
|
||||
REPLY_CODES = {
|
||||
b"\x00": "Succeeded",
|
||||
b"\x01": "General SOCKS server failure",
|
||||
b"\x02": "Connection not allowed by ruleset",
|
||||
b"\x03": "Network unreachable",
|
||||
b"\x04": "Host unreachable",
|
||||
b"\x05": "Connection refused",
|
||||
b"\x06": "TTL expired",
|
||||
b"\x07": "Command not supported",
|
||||
b"\x08": "Address type not supported",
|
||||
}
|
||||
|
||||
|
||||
async def _init_socks5_connection(
|
||||
stream: AsyncNetworkStream,
|
||||
*,
|
||||
host: bytes,
|
||||
port: int,
|
||||
auth: tuple[bytes, bytes] | None = None,
|
||||
) -> None:
|
||||
conn = socksio.socks5.SOCKS5Connection()
|
||||
|
||||
# Auth method request
|
||||
auth_method = (
|
||||
socksio.socks5.SOCKS5AuthMethod.NO_AUTH_REQUIRED
|
||||
if auth is None
|
||||
else socksio.socks5.SOCKS5AuthMethod.USERNAME_PASSWORD
|
||||
)
|
||||
conn.send(socksio.socks5.SOCKS5AuthMethodsRequest([auth_method]))
|
||||
outgoing_bytes = conn.data_to_send()
|
||||
await stream.write(outgoing_bytes)
|
||||
|
||||
# Auth method response
|
||||
incoming_bytes = await stream.read(max_bytes=4096)
|
||||
response = conn.receive_data(incoming_bytes)
|
||||
assert isinstance(response, socksio.socks5.SOCKS5AuthReply)
|
||||
if response.method != auth_method:
|
||||
requested = AUTH_METHODS.get(auth_method, "UNKNOWN")
|
||||
responded = AUTH_METHODS.get(response.method, "UNKNOWN")
|
||||
raise ProxyError(
|
||||
f"Requested {requested} from proxy server, but got {responded}."
|
||||
)
|
||||
|
||||
if response.method == socksio.socks5.SOCKS5AuthMethod.USERNAME_PASSWORD:
|
||||
# Username/password request
|
||||
assert auth is not None
|
||||
username, password = auth
|
||||
conn.send(socksio.socks5.SOCKS5UsernamePasswordRequest(username, password))
|
||||
outgoing_bytes = conn.data_to_send()
|
||||
await stream.write(outgoing_bytes)
|
||||
|
||||
# Username/password response
|
||||
incoming_bytes = await stream.read(max_bytes=4096)
|
||||
response = conn.receive_data(incoming_bytes)
|
||||
assert isinstance(response, socksio.socks5.SOCKS5UsernamePasswordReply)
|
||||
if not response.success:
|
||||
raise ProxyError("Invalid username/password")
|
||||
|
||||
# Connect request
|
||||
conn.send(
|
||||
socksio.socks5.SOCKS5CommandRequest.from_address(
|
||||
socksio.socks5.SOCKS5Command.CONNECT, (host, port)
|
||||
)
|
||||
)
|
||||
outgoing_bytes = conn.data_to_send()
|
||||
await stream.write(outgoing_bytes)
|
||||
|
||||
# Connect response
|
||||
incoming_bytes = await stream.read(max_bytes=4096)
|
||||
response = conn.receive_data(incoming_bytes)
|
||||
assert isinstance(response, socksio.socks5.SOCKS5Reply)
|
||||
if response.reply_code != socksio.socks5.SOCKS5ReplyCode.SUCCEEDED:
|
||||
reply_code = REPLY_CODES.get(response.reply_code, "UNKOWN")
|
||||
raise ProxyError(f"Proxy Server could not connect: {reply_code}.")
|
||||
|
||||
|
||||
class AsyncSOCKSProxy(AsyncConnectionPool): # pragma: nocover
|
||||
"""
|
||||
A connection pool that sends requests via an HTTP proxy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proxy_url: URL | bytes | str,
|
||||
proxy_auth: tuple[bytes | str, bytes | str] | None = None,
|
||||
ssl_context: ssl.SSLContext | None = None,
|
||||
max_connections: int | None = 10,
|
||||
max_keepalive_connections: int | None = None,
|
||||
keepalive_expiry: float | None = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
retries: int = 0,
|
||||
network_backend: AsyncNetworkBackend | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
A connection pool for making HTTP requests.
|
||||
|
||||
Parameters:
|
||||
proxy_url: The URL to use when connecting to the proxy server.
|
||||
For example `"http://127.0.0.1:8080/"`.
|
||||
ssl_context: An SSL context to use for verifying connections.
|
||||
If not specified, the default `httpcore.default_ssl_context()`
|
||||
will be used.
|
||||
max_connections: The maximum number of concurrent HTTP connections that
|
||||
the pool should allow. Any attempt to send a request on a pool that
|
||||
would exceed this amount will block until a connection is available.
|
||||
max_keepalive_connections: The maximum number of idle HTTP connections
|
||||
that will be maintained in the pool.
|
||||
keepalive_expiry: The duration in seconds that an idle HTTP connection
|
||||
may be maintained for before being expired from the pool.
|
||||
http1: A boolean indicating if HTTP/1.1 requests should be supported
|
||||
by the connection pool. Defaults to True.
|
||||
http2: A boolean indicating if HTTP/2 requests should be supported by
|
||||
the connection pool. Defaults to False.
|
||||
retries: The maximum number of retries when trying to establish
|
||||
a connection.
|
||||
local_address: Local address to connect from. Can also be used to
|
||||
connect using a particular address family. Using
|
||||
`local_address="0.0.0.0"` will connect using an `AF_INET` address
|
||||
(IPv4), while using `local_address="::"` will connect using an
|
||||
`AF_INET6` address (IPv6).
|
||||
uds: Path to a Unix Domain Socket to use instead of TCP sockets.
|
||||
network_backend: A backend instance to use for handling network I/O.
|
||||
"""
|
||||
super().__init__(
|
||||
ssl_context=ssl_context,
|
||||
max_connections=max_connections,
|
||||
max_keepalive_connections=max_keepalive_connections,
|
||||
keepalive_expiry=keepalive_expiry,
|
||||
http1=http1,
|
||||
http2=http2,
|
||||
network_backend=network_backend,
|
||||
retries=retries,
|
||||
)
|
||||
self._ssl_context = ssl_context
|
||||
self._proxy_url = enforce_url(proxy_url, name="proxy_url")
|
||||
if proxy_auth is not None:
|
||||
username, password = proxy_auth
|
||||
username_bytes = enforce_bytes(username, name="proxy_auth")
|
||||
password_bytes = enforce_bytes(password, name="proxy_auth")
|
||||
self._proxy_auth: tuple[bytes, bytes] | None = (
|
||||
username_bytes,
|
||||
password_bytes,
|
||||
)
|
||||
else:
|
||||
self._proxy_auth = None
|
||||
|
||||
def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
|
||||
return AsyncSocks5Connection(
|
||||
proxy_origin=self._proxy_url.origin,
|
||||
remote_origin=origin,
|
||||
proxy_auth=self._proxy_auth,
|
||||
ssl_context=self._ssl_context,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
http1=self._http1,
|
||||
http2=self._http2,
|
||||
network_backend=self._network_backend,
|
||||
)
|
||||
|
||||
|
||||
class AsyncSocks5Connection(AsyncConnectionInterface):
|
||||
def __init__(
|
||||
self,
|
||||
proxy_origin: Origin,
|
||||
remote_origin: Origin,
|
||||
proxy_auth: tuple[bytes, bytes] | None = None,
|
||||
ssl_context: ssl.SSLContext | None = None,
|
||||
keepalive_expiry: float | None = None,
|
||||
http1: bool = True,
|
||||
http2: bool = False,
|
||||
network_backend: AsyncNetworkBackend | None = None,
|
||||
) -> None:
|
||||
self._proxy_origin = proxy_origin
|
||||
self._remote_origin = remote_origin
|
||||
self._proxy_auth = proxy_auth
|
||||
self._ssl_context = ssl_context
|
||||
self._keepalive_expiry = keepalive_expiry
|
||||
self._http1 = http1
|
||||
self._http2 = http2
|
||||
|
||||
self._network_backend: AsyncNetworkBackend = (
|
||||
AutoBackend() if network_backend is None else network_backend
|
||||
)
|
||||
self._connect_lock = AsyncLock()
|
||||
self._connection: AsyncConnectionInterface | None = None
|
||||
self._connect_failed = False
|
||||
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
timeouts = request.extensions.get("timeout", {})
|
||||
sni_hostname = request.extensions.get("sni_hostname", None)
|
||||
timeout = timeouts.get("connect", None)
|
||||
|
||||
async with self._connect_lock:
|
||||
if self._connection is None:
|
||||
try:
|
||||
# Connect to the proxy
|
||||
kwargs = {
|
||||
"host": self._proxy_origin.host.decode("ascii"),
|
||||
"port": self._proxy_origin.port,
|
||||
"timeout": timeout,
|
||||
}
|
||||
async with Trace("connect_tcp", logger, request, kwargs) as trace:
|
||||
stream = await self._network_backend.connect_tcp(**kwargs)
|
||||
trace.return_value = stream
|
||||
|
||||
# Connect to the remote host using socks5
|
||||
kwargs = {
|
||||
"stream": stream,
|
||||
"host": self._remote_origin.host.decode("ascii"),
|
||||
"port": self._remote_origin.port,
|
||||
"auth": self._proxy_auth,
|
||||
}
|
||||
async with Trace(
|
||||
"setup_socks5_connection", logger, request, kwargs
|
||||
) as trace:
|
||||
await _init_socks5_connection(**kwargs)
|
||||
trace.return_value = stream
|
||||
|
||||
# Upgrade the stream to SSL
|
||||
if self._remote_origin.scheme == b"https":
|
||||
ssl_context = (
|
||||
default_ssl_context()
|
||||
if self._ssl_context is None
|
||||
else self._ssl_context
|
||||
)
|
||||
alpn_protocols = (
|
||||
["http/1.1", "h2"] if self._http2 else ["http/1.1"]
|
||||
)
|
||||
ssl_context.set_alpn_protocols(alpn_protocols)
|
||||
|
||||
kwargs = {
|
||||
"ssl_context": ssl_context,
|
||||
"server_hostname": sni_hostname
|
||||
or self._remote_origin.host.decode("ascii"),
|
||||
"timeout": timeout,
|
||||
}
|
||||
async with Trace("start_tls", logger, request, kwargs) as trace:
|
||||
stream = await stream.start_tls(**kwargs)
|
||||
trace.return_value = stream
|
||||
|
||||
# Determine if we should be using HTTP/1.1 or HTTP/2
|
||||
ssl_object = stream.get_extra_info("ssl_object")
|
||||
http2_negotiated = (
|
||||
ssl_object is not None
|
||||
and ssl_object.selected_alpn_protocol() == "h2"
|
||||
)
|
||||
|
||||
# Create the HTTP/1.1 or HTTP/2 connection
|
||||
if http2_negotiated or (
|
||||
self._http2 and not self._http1
|
||||
): # pragma: nocover
|
||||
from .http2 import AsyncHTTP2Connection
|
||||
|
||||
self._connection = AsyncHTTP2Connection(
|
||||
origin=self._remote_origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
else:
|
||||
self._connection = AsyncHTTP11Connection(
|
||||
origin=self._remote_origin,
|
||||
stream=stream,
|
||||
keepalive_expiry=self._keepalive_expiry,
|
||||
)
|
||||
except Exception as exc:
|
||||
self._connect_failed = True
|
||||
raise exc
|
||||
elif not self._connection.is_available(): # pragma: nocover
|
||||
raise ConnectionNotAvailable()
|
||||
|
||||
return await self._connection.handle_async_request(request)
|
||||
|
||||
def can_handle_request(self, origin: Origin) -> bool:
|
||||
return origin == self._remote_origin
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self._connection is not None:
|
||||
await self._connection.aclose()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
if self._connection is None: # pragma: nocover
|
||||
# If HTTP/2 support is enabled, and the resulting connection could
|
||||
# end up as HTTP/2 then we should indicate the connection as being
|
||||
# available to service multiple requests.
|
||||
return (
|
||||
self._http2
|
||||
and (self._remote_origin.scheme == b"https" or not self._http1)
|
||||
and not self._connect_failed
|
||||
)
|
||||
return self._connection.is_available()
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
if self._connection is None: # pragma: nocover
|
||||
return self._connect_failed
|
||||
return self._connection.has_expired()
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
if self._connection is None: # pragma: nocover
|
||||
return self._connect_failed
|
||||
return self._connection.is_idle()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
if self._connection is None: # pragma: nocover
|
||||
return self._connect_failed
|
||||
return self._connection.is_closed()
|
||||
|
||||
def info(self) -> str:
|
||||
if self._connection is None: # pragma: nocover
|
||||
return "CONNECTION FAILED" if self._connect_failed else "CONNECTING"
|
||||
return self._connection.info()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__} [{self.info()}]>"
|
||||
Reference in New Issue
Block a user