from __future__ import annotations

import re
from typing import Any, List, Tuple

from app.datasource import DataSource
from app.models import FilterSpec, QueryIntent
from app.semantic import SemanticLayer


class SqlBuildError(ValueError):
    """Raised when a query intent cannot be compiled into SQL.

    Subclasses ``ValueError`` so callers that already catch ``ValueError``
    keep working.
    """


def _metric_expr(semantic: SemanticLayer, entity: str, metric: str) -> str:
    """Resolve *metric* on *entity* to its SQL expression (the metric's ``expr``).

    Raises:
        SqlBuildError: if the entity or the metric is not defined in *semantic*.
    """
    entity_def = semantic.entities.get(entity)
    if not entity_def:
        raise SqlBuildError(f"Unknown entity: {entity}")

    metric_def = entity_def.metrics.get(metric)
    if metric_def:
        return metric_def.expr
    raise SqlBuildError(f"Unknown metric '{metric}' for entity '{entity}'")


def _dim_column(semantic: SemanticLayer, entity: str, dim: str) -> str:
    """Resolve dimension *dim* on *entity* to its physical column (``column``).

    Raises:
        SqlBuildError: if the entity or the dimension is not defined in *semantic*.
    """
    entity_def = semantic.entities.get(entity)
    if not entity_def:
        raise SqlBuildError(f"Unknown entity: {entity}")

    dim_def = entity_def.dimensions.get(dim)
    if dim_def:
        return dim_def.column
    raise SqlBuildError(f"Unknown dimension '{dim}' for entity '{entity}'")


def _time_column(semantic: SemanticLayer, entity: str) -> str:
    """Return the ``time_column`` configured for *entity*.

    Raises:
        SqlBuildError: if *entity* is not defined in *semantic*.
    """
    entity_def = semantic.entities.get(entity)
    if entity_def:
        return entity_def.time_column
    raise SqlBuildError(f"Unknown entity: {entity}")


def _filter_clause(col: str, f: FilterSpec) -> Tuple[str, List[Any]]:
    """Translate one filter on column *col* into a WHERE fragment plus bound values.

    Raises:
        SqlBuildError: for a malformed ``in``/``between`` value or an unsupported op.
    """
    op = f.op
    if op in ("=", "!="):
        if f.value is None:
            # ``col = ?`` bound to NULL is never true in SQL, so a None value
            # would silently match nothing; use the NULL-aware predicates.
            return (f"{col} IS NULL" if op == "=" else f"{col} IS NOT NULL"), []
        return f"{col} {op} ?", [f.value]
    if op == "in":
        if not isinstance(f.value, (list, tuple)) or not f.value:
            raise SqlBuildError("IN filter requires non-empty list value")
        placeholders = ",".join("?" for _ in f.value)
        return f"{col} IN ({placeholders})", list(f.value)
    if op == "between":
        if f.value is None or f.value_end is None:
            # A NULL bound makes BETWEEN evaluate to NULL (no rows) instead of
            # failing, so reject it explicitly.
            raise SqlBuildError("BETWEEN filter requires both value and value_end")
        return f"{col} BETWEEN ? AND ?", [f.value, f.value_end]
    if op in (">", "<", ">=", "<="):
        return f"{col} {op} ?", [f.value]
    raise SqlBuildError(f"Unsupported filter op: {op}")


def build_sql(
    semantic: SemanticLayer,
    source: DataSource,
    intent: QueryIntent,
) -> Tuple[str, Tuple[Any, ...]]:
    """Compile *intent* into a parameterized ``SELECT`` statement.

    Time-range and filter values are bound through ``?`` placeholders. Column
    names and metric expressions come from the semantic layer. ``order_by``
    and ``limit`` are spliced into the SQL text, so both are validated first.

    Args:
        semantic: Semantic layer used to resolve the entity, its dimensions
            and its metrics.
        source: Data source used to qualify the entity's physical table name.
        intent: The query to compile.

    Returns:
        A ``(sql, params)`` pair, where ``params`` lines up positionally with
        the ``?`` placeholders in ``sql``.

    Raises:
        SqlBuildError: if a name cannot be resolved, a filter is malformed,
            a time range is given for an entity without a time column,
            ``order_by`` is not a plain identifier, ``limit`` is not a positive
            integer, or the intent selects no metrics or dimensions.
    """
    ent = semantic.entities.get(intent.entity)
    if not ent:
        raise SqlBuildError(f"Unknown entity: {intent.entity}")

    table = source.qualify_table(ent.physical_table)
    time_col = ent.time_column
    params: List[Any] = []
    where_clauses: List[str] = []

    if intent.time_range:
        start, end = intent.time_range
        if (start or end) and not time_col:
            # Without this check the column name would render as e.g. "None >= ?".
            raise SqlBuildError(
                f"Entity '{intent.entity}' has no time column for time_range"
            )
        # Both bounds are inclusive; a falsy bound leaves that side open.
        if start:
            where_clauses.append(f"{time_col} >= ?")
            params.append(start)
        if end:
            where_clauses.append(f"{time_col} <= ?")
            params.append(end)

    for f in intent.filters:
        col = _dim_column(semantic, intent.entity, f.dimension)
        clause, values = _filter_clause(col, f)
        where_clauses.append(clause)
        params.extend(values)

    select_parts: List[str] = []
    group_parts: List[str] = []
    for d in intent.dimensions:
        c = _dim_column(semantic, intent.entity, d)
        select_parts.append(f"{c} AS {d}")
        group_parts.append(c)  # every selected dimension is also a grouping key
    for m in intent.metrics:
        select_parts.append(f"{_metric_expr(semantic, intent.entity, m)} AS {m}")

    if not select_parts:
        raise SqlBuildError("No metrics or dimensions in intent")

    order_sql = ""
    if intent.order_by:
        # order_by is spliced into the SQL text, so only a bare identifier is allowed.
        if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", intent.order_by):
            raise SqlBuildError("Invalid order_by identifier")
        order_sql = f"ORDER BY {intent.order_by} DESC"

    # LIMIT is also spliced in rather than bound. Coerce it to a real int so
    # arbitrary text can never reach the SQL. Reject values < 1: some engines
    # (e.g. SQLite) treat a negative LIMIT as "no limit".
    try:
        lim = int(intent.limit or 500)
    except (TypeError, ValueError) as exc:
        raise SqlBuildError(f"Invalid limit: {intent.limit!r}") from exc
    if lim < 1:
        raise SqlBuildError(f"Invalid limit: {intent.limit!r}")

    pieces = [
        f"SELECT {', '.join(select_parts)} FROM {table}",
        f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "",
        f"GROUP BY {', '.join(group_parts)}" if group_parts else "",
        order_sql,
        f"LIMIT {lim}",
    ]
    # Joining only the non-empty pieces avoids doubled spaces where optional
    # clauses are absent.
    sql = " ".join(piece for piece in pieces if piece)
    return sql, tuple(params)
