# Source code for aiosmtplib.connection

"""
Handles client connection/disconnection.
"""
import asyncio
import socket
import ssl
from typing import Any, Optional, Type, Union  # NOQA

from .default import Default, _default
from .errors import (
    SMTPConnectError,
    SMTPResponseException,
    SMTPServerDisconnected,
    SMTPTimeoutError,
)
from .protocol import SMTPProtocol
from .response import SMTPResponse
from .status import SMTPStatus


__all__ = ("SMTPConnection",)


# Size limit, in bytes, handed to the stream reader used for server replies.
MAX_LINE_LENGTH = 8192
# Port used when the connection is made in plain text.
SMTP_PORT = 25
# Port used when the initial connection is made over TLS/SSL.
SMTP_TLS_PORT = 465
# Timeout applied to connection operations when none is given, in seconds.
DEFAULT_TIMEOUT = 60


# Plain numeric type accepted for timeouts.
NumType = Union[float, int]
# The ``Default*`` aliases additionally accept the ``_default`` sentinel, so
# that "argument not passed" can be told apart from an explicit ``None``.
DefaultNumType = Union[NumType, Default]
DefaultStrType = Union[str, Default]
DefaultSSLContextType = Union[ssl.SSLContext, Default]


class SMTPConnection:
    """
    Handles connection/disconnection from the SMTP server provided.

    Keyword arguments can be provided either on :meth:`__init__` or when
    calling the :meth:`connect` method. Note that in both cases these options
    are saved for later use; subsequent calls to :meth:`connect` will use the
    same options, unless new ones are provided.

    A single :class:`asyncio.Lock` guards the connection: it is acquired by
    :meth:`connect` and released by :meth:`close`, so a second call to
    :meth:`connect` waits until the current connection has been closed.
    """

    def __init__(
        self,
        hostname: str = "",
        port: Optional[int] = None,
        source_address: Optional[str] = None,
        timeout: NumType = DEFAULT_TIMEOUT,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        use_tls: bool = False,
        validate_certs: bool = True,
        client_cert: Optional[str] = None,
        client_key: Optional[str] = None,
        tls_context: Optional[ssl.SSLContext] = None,
        cert_bundle: Optional[str] = None,
    ) -> None:
        """
        :keyword hostname:  Server name (or IP) to connect to
        :keyword port: Server port. Defaults to ``25`` if ``use_tls`` is
            ``False``, ``465`` if ``use_tls`` is ``True``.
        :keyword source_address: The hostname of the client. Defaults to the
            result of :func:`socket.getfqdn`. Note that this call blocks.
        :keyword timeout: Default timeout value for the connection, in seconds.
            Defaults to 60.
        :keyword loop: event loop to run on. If not set, uses
            :func:`asyncio.get_event_loop()`.
        :keyword use_tls: If True, make the initial connection to the server
            over TLS/SSL. Note that if the server supports STARTTLS only, this
            should be False.
        :keyword validate_certs: Determines if server certificates are
            validated. Defaults to True.
        :keyword client_cert: Path to client side certificate, for TLS
            verification.
        :keyword client_key: Path to client side key, for TLS verification.
        :keyword tls_context: An existing :class:`ssl.SSLContext`, for TLS
            verification. Mutually exclusive with ``client_cert``/
            ``client_key``.
        :keyword cert_bundle: Path to certificate bundle, for TLS verification.

        :raises ValueError: mutually exclusive options provided
        """
        self.protocol = None  # type: Optional[SMTPProtocol]
        self.transport = None  # type: Optional[asyncio.BaseTransport]

        if tls_context is not None and client_cert is not None:
            raise ValueError(
                "Either a TLS context or a certificate/key must be provided"
            )

        # Kwarg defaults are provided here, and saved for connect.
        self.hostname = hostname
        self.port = port
        self.timeout = timeout
        self.use_tls = use_tls
        self._source_address = source_address
        self.validate_certs = validate_certs
        self.client_cert = client_cert
        self.client_key = client_key
        self.tls_context = tls_context
        self.cert_bundle = cert_bundle

        self.loop = loop or asyncio.get_event_loop()
        # NOTE(review): the ``loop`` argument of asyncio primitives was
        # removed in Python 3.10; kept here to preserve existing behaviour.
        self._connect_lock = asyncio.Lock(loop=self.loop)

    async def __aenter__(self) -> "SMTPConnection":
        """
        Connect (if not already connected) and return this instance.
        """
        if not self.is_connected:
            await self.connect()

        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        traceback: Any,
    ) -> None:
        """
        Leave the ``async with`` block, shutting the connection down.

        If the block failed with a connection or timeout error, or the
        transport is already gone, the connection is simply closed. Otherwise
        :meth:`quit` is attempted first, falling back to :meth:`close` if
        that fails. The exception from the block (if any) is never
        suppressed.
        """
        # Use issubclass so that ConnectionError subclasses (e.g.
        # ConnectionResetError, BrokenPipeError) are recognised too; exc_type
        # is None when the block exited without an exception.
        is_connection_error = exc_type is not None and issubclass(
            exc_type, (ConnectionError, SMTPTimeoutError)
        )
        if is_connection_error or not self.is_connected:
            self.close()
        else:
            try:
                await self.quit()
            except (ConnectionError, SMTPResponseException, SMTPTimeoutError):
                self.close()

    @property
    def is_connected(self) -> bool:
        """
        Check if our transport is still connected.
        """
        return bool(self.transport and not self.transport.is_closing())

    @property
    def source_address(self) -> str:
        """
        Get the system hostname to be sent to the SMTP server.
        Simply caches the result of :func:`socket.getfqdn`.
        """
        if self._source_address is None:
            self._source_address = socket.getfqdn()

        return self._source_address

    async def connect(
        self,
        hostname: Optional[str] = None,
        port: Optional[int] = None,
        source_address: DefaultStrType = _default,
        timeout: DefaultNumType = _default,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        use_tls: Optional[bool] = None,
        validate_certs: Optional[bool] = None,
        client_cert: DefaultStrType = _default,
        client_key: DefaultStrType = _default,
        tls_context: DefaultSSLContextType = _default,
        cert_bundle: DefaultStrType = _default,
    ) -> SMTPResponse:
        """
        Initialize a connection to the server. Options provided to
        :meth:`.connect` take precedence over those used to initialize the
        class.

        The connect lock is acquired before anything else and stays held
        while the connection is open; it is released by :meth:`close`. If
        connecting fails for any reason the connection is closed (and the
        lock released) before the error is re-raised.

        :keyword hostname:  Server name (or IP) to connect to
        :keyword port: Server port. Defaults to 25 if ``use_tls`` is
            False, 465 if ``use_tls`` is True.
        :keyword source_address: The hostname of the client. Defaults to the
            result of :func:`socket.getfqdn`. Note that this call blocks.
        :keyword timeout: Default timeout value for the connection, in seconds.
            Defaults to 60.
        :keyword loop: event loop to run on. If not set, uses
            :func:`asyncio.get_event_loop()`.
        :keyword use_tls: If True, make the initial connection to the server
            over TLS/SSL. Note that if the server supports STARTTLS only, this
            should be False.
        :keyword validate_certs: Determines if server certificates are
            validated. Defaults to True.
        :keyword client_cert: Path to client side certificate, for TLS.
        :keyword client_key: Path to client side key, for TLS.
        :keyword tls_context: An existing :class:`ssl.SSLContext`, for TLS.
            Mutually exclusive with ``client_cert``/``client_key``.
        :keyword cert_bundle: Path to certificate bundle, for TLS verification.

        :raises ValueError: mutually exclusive options provided
        :raises SMTPConnectError: the connection could not be established, or
            the server greeting was not a "ready" response
        :raises SMTPTimeoutError: connecting or reading the greeting timed out
        """
        await self._connect_lock.acquire()

        try:
            if hostname is not None:
                self.hostname = hostname
            if loop is not None:
                self.loop = loop
            if use_tls is not None:
                self.use_tls = use_tls
            if validate_certs is not None:
                self.validate_certs = validate_certs

            if port is not None:
                self.port = port

            if self.port is None:
                self.port = SMTP_TLS_PORT if self.use_tls else SMTP_PORT

            # ``_default`` (rather than None) marks "not passed" for options
            # where None is itself a meaningful value.
            if timeout is not _default:
                self.timeout = timeout  # type: ignore
            if source_address is not _default:
                self._source_address = source_address  # type: ignore
            if client_cert is not _default:
                self.client_cert = client_cert  # type: ignore
            if client_key is not _default:
                self.client_key = client_key  # type: ignore
            if tls_context is not _default:
                self.tls_context = tls_context  # type: ignore
            if cert_bundle is not _default:
                self.cert_bundle = cert_bundle  # type: ignore

            if self.tls_context is not None and self.client_cert is not None:
                raise ValueError(
                    "Either a TLS context or a certificate/key must be provided"
                )

            response = await self._create_connection()
        except BaseException:
            # Whatever went wrong (bad options, TLS setup failure, an
            # unexpected protocol error, task cancellation), do not leave the
            # connect lock held: that would make every later connect() call
            # wait forever. close() is idempotent, so it is safe even when
            # _create_connection() has already cleaned up.
            self.close()
            raise

        return response

    async def _create_connection(self) -> SMTPResponse:
        """
        Open the transport and read the server's greeting.

        On success, ``self.protocol`` and ``self.transport`` are set and the
        greeting response is returned. On failure the connection is closed
        before the error is raised.

        :raises SMTPConnectError: an :class:`OSError` occurred while
            connecting, or the greeting code was not ``SMTPStatus.ready``
        :raises SMTPTimeoutError: connecting or reading the greeting took
            longer than ``self.timeout``
        """
        assert self.hostname is not None, "Hostname must be set"
        assert self.port is not None, "Port must be set"

        reader = asyncio.StreamReader(limit=MAX_LINE_LENGTH, loop=self.loop)
        protocol = SMTPProtocol(reader, loop=self.loop)

        tls_context = None  # type: Optional[ssl.SSLContext]
        if self.use_tls:
            tls_context = self._get_tls_context()

        connect_future = self.loop.create_connection(
            lambda: protocol, host=self.hostname, port=self.port, ssl=tls_context
        )
        try:
            transport, _ = await asyncio.wait_for(
                connect_future, timeout=self.timeout, loop=self.loop
            )
        except asyncio.TimeoutError as exc:
            # Handled before OSError: on Python 3.11+ asyncio.TimeoutError is
            # the builtin TimeoutError, which is itself an OSError subclass.
            self.close()
            raise SMTPTimeoutError(str(exc)) from exc
        except OSError as err:
            # Covers ConnectionRefusedError and every other socket/SSL error.
            self.close()
            raise SMTPConnectError(
                "Error connecting to {host} on port {port}: {err}".format(
                    host=self.hostname, port=self.port, err=err
                )
            ) from err

        self.protocol = protocol
        self.transport = transport

        # The server speaks first: wait for its greeting.
        waiter = asyncio.Task(protocol.read_response(), loop=self.loop)

        try:
            response = await asyncio.wait_for(
                waiter, timeout=self.timeout, loop=self.loop
            )
        except asyncio.TimeoutError as exc:
            self.close()
            raise SMTPTimeoutError(str(exc)) from exc

        if response.code != SMTPStatus.ready:
            self.close()
            raise SMTPConnectError(str(response))

        return response

    async def execute_command(
        self, *args: bytes, timeout: DefaultNumType = _default
    ) -> SMTPResponse:
        """
        Check that we're connected, if we got a timeout value, and then
        pass the command to the protocol.

        :param args: command parts, passed through to the protocol
        :keyword timeout: timeout for this command; when not given, the
            connection's ``timeout`` is used
        :raises SMTPServerDisconnected: connection lost
        """
        if timeout is _default:
            timeout = self.timeout  # type: ignore

        self._raise_error_if_disconnected()

        try:
            response = await self.protocol.execute_command(  # type: ignore
                *args, timeout=timeout
            )
        except SMTPServerDisconnected:
            # On disconnect, clean up the connection.
            self.close()
            raise

        # If the server is unavailable, be nice and close the connection
        if response.code == SMTPStatus.domain_unavailable:
            self.close()

        return response

    async def quit(self, timeout: DefaultNumType = _default) -> SMTPResponse:
        """
        Placeholder for the SMTP ``QUIT`` command, used by :meth:`__aexit__`.

        Not implemented on this class; presumably meant to be provided by a
        subclass.

        :raises NotImplementedError: always
        """
        raise NotImplementedError

    def _get_tls_context(self) -> ssl.SSLContext:
        """
        Build an SSLContext object from the options we've been given.

        A user supplied ``tls_context`` is returned untouched. Otherwise a
        default client context is created and configured from
        ``validate_certs``, ``cert_bundle`` and ``client_cert``/
        ``client_key``.
        """
        if self.tls_context is not None:
            context = self.tls_context
        else:
            # SERVER_AUTH is what we want for a client side socket
            context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
            # check_hostname must be updated before verify_mode: the ssl
            # module refuses CERT_NONE while check_hostname is still enabled.
            context.check_hostname = bool(self.validate_certs)
            if self.validate_certs:
                context.verify_mode = ssl.CERT_REQUIRED
            else:
                context.verify_mode = ssl.CERT_NONE

            if self.cert_bundle is not None:
                context.load_verify_locations(cafile=self.cert_bundle)

            if self.client_cert is not None:
                context.load_cert_chain(self.client_cert, keyfile=self.client_key)

        return context

    def _raise_error_if_disconnected(self) -> None:
        """
        See if we're still connected, and if not, raise
        ``SMTPServerDisconnected``.

        The connection state is cleaned up via :meth:`close` before raising.
        """
        if (
            self.transport is None
            or self.protocol is None
            or self.transport.is_closing()
        ):
            self.close()
            raise SMTPServerDisconnected("Disconnected from SMTP server")

    def close(self) -> None:
        """
        Closes the connection.

        Closes the transport if it is still open, releases the connect lock
        if it is held, and clears ``protocol``/``transport``. Safe to call
        more than once.
        """
        if self.transport is not None and not self.transport.is_closing():
            self.transport.close()

        # The lock was taken by connect(); releasing it lets the next
        # connect() call proceed.
        if self._connect_lock.locked():
            self._connect_lock.release()

        self.protocol = None
        self.transport = None

    def get_transport_info(self, key: str) -> Any:
        """
        Get extra info from the transport.
        Supported keys:

            - ``peername``
            - ``socket``
            - ``sockname``
            - ``compression``
            - ``cipher``
            - ``peercert``
            - ``sslcontext``
            - ``sslobject``

        :raises SMTPServerDisconnected: connection lost
        """
        self._raise_error_if_disconnected()

        return self.transport.get_extra_info(key)  # type: ignore