Source code for yadisk.sessions.aiohttp_session

# -*- coding: utf-8 -*-

from typing import Optional, Any, Union

from ..common import CaseInsensitiveDict

from ..exceptions import (
    RequestError, TooManyRedirectsError,
    RequestTimeoutError, YaDiskConnectionError
)

from ..async_session import AsyncSession, AsyncResponse
from ..compat import Iterable
from ..common import is_async_func
from ..types import JSON, AsyncConsumeCallback, Headers, TimeoutParameter, HTTPMethod

import aiohttp
import sys

__all__ = ["AIOHTTPSession"]

def convert_aiohttp_exception(exc: aiohttp.ClientError) -> Union[RequestError, aiohttp.ClientError]:
    if isinstance(exc, aiohttp.TooManyRedirects):
        return TooManyRedirectsError(str(exc))
    elif isinstance(exc, aiohttp.ServerTimeoutError):
        return RequestTimeoutError(str(exc))
    elif isinstance(exc, aiohttp.ClientConnectionError):
        return YaDiskConnectionError(str(exc))
    elif isinstance(exc, aiohttp.ClientError):
        return RequestError(str(exc))
    else:
        return exc

class AIOHTTPResponse(AsyncResponse):
    def __init__(self, response: aiohttp.ClientResponse):
        self._response = response
        self.status = response.status

    async def json(self) -> JSON:
        return await self._response.json()

    async def download(self, consume_callback: AsyncConsumeCallback) -> None:
        callback: Any = consume_callback

        try:
            if is_async_func(consume_callback):
                async for chunk in self._response.content.iter_chunked(8192):
                    await callback(chunk)
            else:
                async for chunk in self._response.content.iter_chunked(8192):
                    callback(chunk)
        except aiohttp.ClientError as e:
            raise convert_aiohttp_exception(e)

    async def close(self) -> None:
        await self._response.release()

def convert_timeout(timeout: TimeoutParameter) -> Optional[aiohttp.ClientTimeout]:
    if timeout is None:
        return None

    if isinstance(timeout, (int, float)):
        return aiohttp.ClientTimeout(sock_connect=timeout, sock_read=timeout)

    connect, read = timeout

    return aiohttp.ClientTimeout(sock_connect=connect, sock_read=read)

DEFAULT_USER_AGENT = "Python/%s.%s aiohttp/%s" % (sys.version_info.major,
                                                  sys.version_info.minor,
                                                  aiohttp.__version__)

[docs] class AIOHTTPSession(AsyncSession): """ .. _aiohttp: https://pypi.org/project/aiohttp :any:`AsyncSession` implementation using the `aiohttp`_ library. All arguments passed in the constructor are directly forwared to :any:`aiohttp.ClientSession`. :ivar aiohttp_session: underlying instance of :any:`aiohttp.ClientSession` To pass `aiohttp`-specific arguments from :any:`AsyncClient` use :code:`aiohttp_args` keyword argument. Usage example: .. code:: python import yadisk async def main(): async with yadisk.AsyncClient(..., session="aiohttp") as client: await client.get_meta( "/my_file.txt", n_retries=5, aiohttp_args={ "proxies": {"https": "http://example.com:1234"}, "verify": False } ) """ def __init__(self, *args, **kwargs): headers = CaseInsensitiveDict({ "User-Agent": DEFAULT_USER_AGENT, "Accept-Encoding": "gzip, deflate", "Accept": "*/*", "Connection": "keep-alive" }) headers.update(kwargs.get("headers", {})) kwargs["headers"] = headers self._session = aiohttp.ClientSession(*args, **kwargs) @property def aiohttp_session(self) -> aiohttp.ClientSession: return self._session def set_headers(self, headers: Headers) -> None: self._session.headers.update(headers) def remove_headers(self, headers: Iterable[str]) -> None: for h in headers: self._session.headers.pop(h, None) async def send_request(self, method: HTTPMethod, url: str, **kwargs) -> AsyncResponse: if "timeout" in kwargs: kwargs["timeout"] = convert_timeout(kwargs["timeout"]) kwargs.pop("stream", None) if "aiohttp_args" in kwargs: kwargs.update(kwargs.pop("aiohttp_args")) try: return AIOHTTPResponse(await self._session.request(method, url, **kwargs)) except aiohttp.ClientError as e: raise convert_aiohttp_exception(e) async def close(self) -> None: await self._session.close()