from __future__ import annotations

from collections.abc import Callable, Collection, Iterable, Mapping, Sequence
from pathlib import Path
from typing import (
    IO,
    TYPE_CHECKING,
    Any,
    Literal,
    Protocol,
    TypedDict,
    TypeVar,
    Union,
)

if TYPE_CHECKING:
    from datetime import date, datetime, time, timedelta
    from decimal import Decimal
    from typing import TypeAlias

    from sqlalchemy.engine import Connection, Engine
    from sqlalchemy.ext.asyncio import (
        AsyncConnection,
        AsyncEngine,
        AsyncSession,
        async_sessionmaker,
    )
    from sqlalchemy.orm import Session

    from polars import DataFrame, Expr, LazyFrame, Series
    from polars._dependencies import numpy as np
    from polars.datatypes import DataType, DataTypeClass, IntegerType, TemporalType
    from polars.lazyframe.engine_config import GPUEngine
    from polars.selectors import Selector


class ArrowArrayExportable(Protocol):
    """Type protocol for Arrow C Data Interface via Arrow PyCapsule Interface."""

    def __arrow_c_array__(
        self, requested_schema: object | None = None
    ) -> tuple[object, object]: ...


class ArrowStreamExportable(Protocol):
    """Type protocol for Arrow C Stream Interface via Arrow PyCapsule Interface."""

    def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: ...


class ArrowSchemaExportable(Protocol):
    """Type protocol for Arrow C Schema Interface via Arrow PyCapsule Interface."""

    def __arrow_c_schema__(self) -> object: ...


class NumpyArray(Protocol):
    """Protocol to match NumPy Arrays without needing NumPy installed."""

    def byteswap(self, *args: Any, **kwargs: Any) -> Any: ...
    def conjugate(self, *args: Any, **kwargs: Any) -> Any: ...
    def ravel(self, *args: Any, **kwargs: Any) -> Any: ...
    def searchsorted(self, *args: Any, **kwargs: Any) -> Any: ...
    def swapaxes(self, *args: Any, **kwargs: Any) -> Any: ...


class PyArrowArray(Protocol):
    """
    Protocol to match PyArrow arrays without needing PyArrow installed.

    Only use for function arguments, not return types.
    """

    def buffers(self, *args: Any, **kwargs: Any) -> Any: ...
    def tolist(self, *args: Any, **kwargs: Any) -> Any: ...


class PyArrowChunkedArray(Protocol):
    """
    Protocol to match PyArrow chunked arrays without needing PyArrow installed.

    Only use for function arguments, not return types.
    """

    def iterchunks(self, *args: Any, **kwargs: Any) -> Any: ...


class PyArrowTable(Protocol):
    """
    Protocol to match PyArrow tables without needing PyArrow installed.

    Only use for function arguments, not return types.
    """

    def filter(self, *args: Any, **kwargs: Any) -> Any: ...
    def group_by(self, *args: Any, **kwargs: Any) -> Any: ...
    def add_column(self, *args: Any, **kwargs: Any) -> Any: ...
    def remove_column(self, *args: Any, **kwargs: Any) -> Any: ...
    def take(self, *args: Any, **kwargs: Any) -> Any: ...
    def to_pandas(self, *args: Any, **kwargs: Any) -> Any: ...


class PandasDataFrame(Protocol):
    """
    Protocol to match pandas dataframes without needing pandas-stubs installed.

    Only use for function arguments, not return types.
    """

    def where(self, *args: Any, **kwargs: Any) -> Any: ...
    def groupby(self, *args: Any, **kwargs: Any) -> Any: ...
    def unstack(self, *args: Any, **kwargs: Any) -> Any: ...
    def pivot_table(self, *args: Any, **kwargs: Any) -> Any: ...


class PandasSeries(Protocol):
    """
    Protocol to match pandas series without needing pandas-stubs installed.

    Only use for function arguments, not return types.
    """

    def to_frame(self, *args: Any, **kwargs: Any) -> Any: ...
    def isna(self, *args: Any, **kwargs: Any) -> Any: ...
    def rename_axis(self, *args: Any, **kwargs: Any) -> Any: ...


class PandasIndex(Protocol):
    """
    Protocol to match pandas indexes without needing pandas-stubs installed.

    Only use for function arguments, not return types.
    """

    def to_series(self, *args: Any, **kwargs: Any) -> Any: ...
    def isna(self, *args: Any, **kwargs: Any) -> Any: ...


class TorchTensor(Protocol):
    """
    Protocol to match PyTorch tensors without needing PyTorch installed.

    Only use for function arguments, not return types.
    """

    def cuda(self, *args: Any, **kwargs: Any) -> Any: ...
    def backward(self, *args: Any, **kwargs: Any) -> Any: ...


# Data types
PolarsDataType: TypeAlias = Union["DataTypeClass", "DataType"]
PolarsTemporalType: TypeAlias = Union[type["TemporalType"], "TemporalType"]
PolarsIntegerType: TypeAlias = Union[type["IntegerType"], "IntegerType"]
OneOrMoreDataTypes: TypeAlias = PolarsDataType | Iterable[PolarsDataType]
PythonDataType: TypeAlias = (
    type[int]
    | type[float]
    | type[bool]
    | type[str]
    | type["date"]
    | type["time"]
    | type["datetime"]
    | type["timedelta"]
    | type[list[Any]]
    | type[tuple[Any, ...]]
    | type[bytes]
    | type[object]
    | type["Decimal"]
    | type[None]
)

SchemaDefinition: TypeAlias = (
    Mapping[str, PolarsDataType | PythonDataType | None]
    | Sequence[str | tuple[str, PolarsDataType | PythonDataType | None]]
)
SchemaDict: TypeAlias = Mapping[str, PolarsDataType]

NumericLiteral: TypeAlias = Union[int, float, "Decimal"]
TemporalLiteral: TypeAlias = Union["date", "time", "datetime", "timedelta"]
NonNestedLiteral: TypeAlias = NumericLiteral | TemporalLiteral | str | bool | bytes
# Python literal types (can convert into a `lit` expression)
PythonLiteral: TypeAlias = Union[NonNestedLiteral, "np.ndarray[Any, Any]", list[Any]]
# Inputs that can convert into a `col` expression
IntoExprColumn: TypeAlias = Union["Expr", "Series", str]
# Inputs that can convert into an expression
IntoExpr: TypeAlias = PythonLiteral | IntoExprColumn | None

ComparisonOperator: TypeAlias = Literal["eq", "neq", "gt", "lt", "gt_eq", "lt_eq"]
Alignment: TypeAlias = Literal["left", "center", "right", "LEFT", "CENTER", "RIGHT"]

# selector type, and related collection/sequence
SelectorType: TypeAlias = "Selector"
ColumnNameOrSelector: TypeAlias = Union["str", SelectorType]

# User-facing string literal types
# The following all have an equivalent Rust enum with the same name
Ambiguous: TypeAlias = Literal["earliest", "latest", "raise", "null"]
AvroCompression: TypeAlias = Literal["uncompressed", "snappy", "deflate"]
CsvQuoteStyle: TypeAlias = Literal["necessary", "always", "non_numeric", "never"]
CategoricalOrdering: TypeAlias = Literal["physical", "lexical"]
CsvCompression: TypeAlias = Literal["uncompressed", "gzip", "zstd"]
CsvEncoding: TypeAlias = Literal["utf8", "utf8-lossy"]
ColumnMapping: TypeAlias = tuple[
    Literal["iceberg-column-mapping"],
    # This is "pa.Schema". Not typed as that causes pyright strict type checking
    # failures for users who don't have pyarrow-stubs installed.
    Any,
]
DefaultFieldValues: TypeAlias = tuple[
    Literal["iceberg"], dict[int, Union["Series", str]]
]
DeletionFiles: TypeAlias = (
    tuple[Literal["iceberg-position-delete"], dict[int, list[str]]]
    | tuple[Literal["delta-deletion-vector"], Callable[["DataFrame"], "DataFrame"]]
)
FillNullStrategy: TypeAlias = Literal[
    "forward", "backward", "min", "max", "mean", "zero", "one"
]
FloatFmt: TypeAlias = Literal["full", "mixed"]
IndexOrder: TypeAlias = Literal["c", "fortran"]
IpcCompression: TypeAlias = Literal["uncompressed", "lz4", "zstd"]
JoinValidation: TypeAlias = Literal["m:m", "m:1", "1:m", "1:1"]
Label: TypeAlias = Literal["left", "right", "datapoint"]
MaintainOrderJoin: TypeAlias = Literal[
    "none", "left", "right", "left_right", "right_left"
]
NdjsonCompression: TypeAlias = Literal["uncompressed", "gzip", "zstd"]
NonExistent: TypeAlias = Literal["raise", "null"]
NullBehavior: TypeAlias = Literal["ignore", "drop"]
ParallelStrategy: TypeAlias = Literal[
    "auto", "columns", "row_groups", "prefiltered", "none"
]
ParquetCompression: TypeAlias = Literal[
    "lz4", "uncompressed", "snappy", "gzip", "brotli", "zstd"
]
PivotAgg: TypeAlias = Literal[
    "min", "max", "first", "last", "sum", "mean", "median", "len", "item"
]
QuantileMethod: TypeAlias = Literal[
    "nearest", "higher", "lower", "midpoint", "linear", "equiprobable"
]
RankMethod: TypeAlias = Literal["average", "min", "max", "dense", "ordinal", "random"]
Roll: TypeAlias = Literal["raise", "forward", "backward"]
RoundMode: TypeAlias = Literal["half_to_even", "half_away_from_zero", "to_zero"]
SerializationFormat: TypeAlias = Literal["binary", "json"]
Endianness: TypeAlias = Literal["little", "big"]
SizeUnit: TypeAlias = Literal[
    "b",
    "kb",
    "mb",
    "gb",
    "tb",
    "bytes",
    "kilobytes",
    "megabytes",
    "gigabytes",
    "terabytes",
]
StartBy: TypeAlias = Literal[
    "window",
    "datapoint",
    "monday",
    "tuesday",
    "wednesday",
    "thursday",
    "friday",
    "saturday",
    "sunday",
]
SyncOnCloseMethod: TypeAlias = Literal["data", "all"]
TimeUnit: TypeAlias = Literal["ns", "us", "ms"]
UnicodeForm: TypeAlias = Literal["NFC", "NFKC", "NFD", "NFKD"]
UniqueKeepStrategy: TypeAlias = Literal["first", "last", "any", "none"]
UnstackDirection: TypeAlias = Literal["vertical", "horizontal"]
MapElementsStrategy: TypeAlias = Literal["thread_local", "threading"]

# The following have a Rust enum equivalent with a different name
AsofJoinStrategy: TypeAlias = Literal["backward", "forward", "nearest"]  # AsofStrategy
ClosedInterval: TypeAlias = Literal["left", "right", "both", "none"]  # ClosedWindow
InterpolationMethod: TypeAlias = Literal["linear", "nearest"]
JoinStrategy: TypeAlias = Literal[
    "inner", "left", "right", "full", "semi", "anti", "cross", "outer"
]  # JoinType
ListToStructWidthStrategy: TypeAlias = Literal["first_non_null", "max_width"]

# The following have no equivalent on the Rust side
ConcatMethod = Literal[
    "vertical",
    "vertical_relaxed",
    "diagonal",
    "diagonal_relaxed",
    "horizontal",
    "align",
    "align_full",
    "align_inner",
    "align_left",
    "align_right",
]
CorrelationMethod: TypeAlias = Literal["pearson", "spearman"]
DbReadEngine: TypeAlias = Literal["adbc", "connectorx"]
DbWriteEngine: TypeAlias = Literal["sqlalchemy", "adbc"]
DbWriteMode: TypeAlias = Literal["replace", "append", "fail"]
EpochTimeUnit = Literal["ns", "us", "ms", "s", "d"]
JaxExportType: TypeAlias = Literal["array", "dict"]
Orientation: TypeAlias = Literal["col", "row"]
SearchSortedSide: TypeAlias = Literal["any", "left", "right"]
TorchExportType: TypeAlias = Literal["tensor", "dataset", "dict"]
TransferEncoding: TypeAlias = Literal["hex", "base64"]
WindowMappingStrategy: TypeAlias = Literal["group_to_rows", "join", "explode"]
ExplainFormat: TypeAlias = Literal["plain", "tree"]

# type signature for allowed series init
ArrayLike: TypeAlias = Union[
    Iterable[Any],
    "Series",
    "PyArrowArray",
    "PyArrowChunkedArray",
    "NumpyArray",
    "PandasSeries",
    "PandasIndex",
    "ArrowArrayExportable",
    "ArrowStreamExportable",
]


# type signature for allowed frame init
FrameInitTypes: TypeAlias = Union[
    Mapping[str, ArrayLike | NonNestedLiteral | None],
    Iterable[Any],
    NumpyArray,
    PyArrowTable,
    PandasDataFrame,
    "ArrowArrayExportable",
    "ArrowStreamExportable",
    TorchTensor,
    "DataFrame",
]

# Excel IO
ColumnFormatDict: TypeAlias = Mapping[
    # dict of colname(s) or selector(s) to format string or dict
    ColumnNameOrSelector | tuple[ColumnNameOrSelector, ...],
    str | Mapping[str, str],
]
ConditionalFormatDict: TypeAlias = Mapping[
    # dict of colname(s) to str, dict, or sequence of str/dict
    ColumnNameOrSelector | Collection[str],
    str | Mapping[str, Any] | Sequence[str | Mapping[str, Any]],
]
ColumnTotalsDefinition: TypeAlias = (
    Mapping[ColumnNameOrSelector | tuple[ColumnNameOrSelector], str]
    | Sequence[str]
    | bool
)
ColumnWidthsDefinition: TypeAlias = (
    Mapping[ColumnNameOrSelector, tuple[str, ...] | int] | int
)
RowTotalsDefinition: TypeAlias = (
    Mapping[str, str | Collection[str]] | Collection[str] | bool
)

# standard/named hypothesis profiles used for parametric testing
ParametricProfileNames: TypeAlias = Literal["fast", "balanced", "expensive"]

# typevars for core polars types
PolarsType = TypeVar("PolarsType", "DataFrame", "LazyFrame", "Series", "Expr")
FrameType = TypeVar("FrameType", "DataFrame", "LazyFrame")
BufferInfo: TypeAlias = tuple[int, int, int]

# type alias for supported spreadsheet engines
ExcelSpreadsheetEngine: TypeAlias = Literal["calamine", "openpyxl", "xlsx2csv"]


class SeriesBuffers(TypedDict):
    """Underlying buffers of a Series."""

    values: Series
    validity: Series | None
    offsets: Series | None


# minimal protocol definitions that can reasonably represent
# an executable connection, cursor, or equivalent object
class BasicConnection(Protocol):
    def cursor(self, *args: Any, **kwargs: Any) -> Any:
        """Return a cursor object."""


class BasicCursor(Protocol):
    def execute(self, *args: Any, **kwargs: Any) -> Any:
        """Execute a query."""


class Cursor(BasicCursor):
    def fetchall(self, *args: Any, **kwargs: Any) -> Any:
        """Fetch all results."""

    def fetchmany(self, *args: Any, **kwargs: Any) -> Any:
        """Fetch results in batches."""


AlchemyConnection: TypeAlias = Union["Connection", "Engine", "Session"]
AlchemyAsyncConnection: TypeAlias = Union[
    "AsyncConnection", "AsyncEngine", "AsyncSession", "async_sessionmaker[AsyncSession]"
]
ConnectionOrCursor: TypeAlias = (
    BasicConnection | BasicCursor | Cursor | AlchemyConnection | AlchemyAsyncConnection
)

# Annotations for `__getitem__` methods
SingleIndexSelector: TypeAlias = int
MultiIndexSelector: TypeAlias = Union[
    slice,
    range,
    Sequence[int],
    "Series",
    "np.ndarray[Any, Any]",
]
SingleNameSelector: TypeAlias = str
MultiNameSelector: TypeAlias = Union[
    slice,
    Sequence[str],
    "Series",
    "np.ndarray[Any, Any]",
]
BooleanMask: TypeAlias = Union[
    Sequence[bool],
    "Series",
    "np.ndarray[Any, Any]",
]
SingleColSelector: TypeAlias = SingleIndexSelector | SingleNameSelector
MultiColSelector: TypeAlias = MultiIndexSelector | MultiNameSelector | BooleanMask

# LazyFrame engine selection
EngineType: TypeAlias = Union[
    Literal["auto", "in-memory", "streaming", "gpu"], "GPUEngine"
]

PlanStage: TypeAlias = Literal["ir", "physical"]

FileSource: TypeAlias = (
    str
    | Path
    | IO[bytes]
    | bytes
    | list[str]
    | list[Path]
    | list[IO[bytes]]
    | list[bytes]
)

JSONEncoder = Callable[[Any], bytes] | Callable[[Any], str]

DeprecationType: TypeAlias = Literal[
    "function",
    "renamed_parameter",
    "streaming_parameter",
    "nonkeyword_arguments",
    "parameter_as_multi_positional",
]


__all__ = [
    "Alignment",
    "Ambiguous",
    "ArrowArrayExportable",
    "ArrowStreamExportable",
    "AsofJoinStrategy",
    "AvroCompression",
    "BooleanMask",
    "BufferInfo",
    "CategoricalOrdering",
    "ClosedInterval",
    "ColumnFormatDict",
    "ColumnNameOrSelector",
    "ColumnTotalsDefinition",
    "ColumnWidthsDefinition",
    "ComparisonOperator",
    "ConcatMethod",
    "ConditionalFormatDict",
    "ConnectionOrCursor",
    "CorrelationMethod",
    "CsvEncoding",
    "CsvQuoteStyle",
    "Cursor",
    "DbReadEngine",
    "DbWriteEngine",
    "DbWriteMode",
    "DeprecationType",
    "Endianness",
    "EngineType",
    "EpochTimeUnit",
    "ExcelSpreadsheetEngine",
    "ExplainFormat",
    "FileSource",
    "FillNullStrategy",
    "FloatFmt",
    "FrameInitTypes",
    "FrameType",
    "IndexOrder",
    "InterpolationMethod",
    "IntoExpr",
    "IntoExprColumn",
    "IpcCompression",
    "JSONEncoder",
    "JaxExportType",
    "JoinStrategy",
    "JoinValidation",
    "Label",
    "ListToStructWidthStrategy",
    "MaintainOrderJoin",
    "MapElementsStrategy",
    "MultiColSelector",
    "MultiIndexSelector",
    "MultiNameSelector",
    "NdjsonCompression",
    "NonExistent",
    "NonNestedLiteral",
    "NullBehavior",
    "NumericLiteral",
    "OneOrMoreDataTypes",
    "Orientation",
    "ParallelStrategy",
    "ParametricProfileNames",
    "ParquetCompression",
    "PivotAgg",
    "PolarsDataType",
    "PolarsIntegerType",
    "PolarsTemporalType",
    "PolarsType",
    "PythonDataType",
    "PythonLiteral",
    "QuantileMethod",
    "RankMethod",
    "Roll",
    "RowTotalsDefinition",
    "SchemaDefinition",
    "SchemaDict",
    "SearchSortedSide",
    "SelectorType",
    "SerializationFormat",
    "SeriesBuffers",
    "SingleColSelector",
    "SingleIndexSelector",
    "SingleNameSelector",
    "SizeUnit",
    "StartBy",
    "SyncOnCloseMethod",
    "TemporalLiteral",
    "TimeUnit",
    "TorchExportType",
    "TransferEncoding",
    "UnicodeForm",
    "UniqueKeepStrategy",
    "UnstackDirection",
    "WindowMappingStrategy",
]


class ParquetMetadataContext:
    """
    The context given when writing file-level parquet metadata.

    .. warning::
        This functionality is considered **experimental**. It may be removed or
        changed at any point without it being considered a breaking change.
    """

    def __init__(self, *, arrow_schema: str) -> None:
        self.arrow_schema = arrow_schema

    arrow_schema: str  #: The base64 encoded arrow schema that is going to be written into metadata.


ParquetMetadataFn: TypeAlias = Callable[[ParquetMetadataContext], dict[str, str]]
ParquetMetadata: TypeAlias = dict[str, str] | ParquetMetadataFn

StorageOptionsDict: TypeAlias = dict[str, Any]
