from fastapi import FastAPI, Query
from fastapi.responses import HTMLResponse
from pathlib import Path
from typing import Optional
import json
from graph import build_graph
from ml import predict_joinability, best_path
from join_generator import generate_multi_join_sql
from semantic_rename import suggest_column_name
from pydantic import BaseModel
from semantic_rename import update_column_name
from table_rename import update_table_name
from schema_cache import (
    get_cached_tables,
    get_cached_columns
)
from report_generator import generate_report_sql
from query_interpreter import interpret
from llm_interpreter import llm_interpret
from json_to_report import json_to_report
from join_generator import generate_multi_join_sql
from ml import best_paths
from sqlalchemy import text
from db import engine
from join_generator import find_join_key
from join_generator import merge_paths
from bi_engine.semantic_loader import semantic_catalog




# The ASGI application; every endpoint in this module registers on it via
# the ``@app.get`` / ``@app.post`` decorators below.
app = FastAPI()

@app.get("/", response_class=HTMLResponse)
def home():
    """Serve the single-page UI from ``templates/index.html``.

    The path is relative to the process working directory, so the server is
    expected to be started from the project root.

    Returns:
        The template's HTML text, or a short HTML error page with status
        404 when the template is missing. Before this change the 404 case
        was a 200, which made a broken deployment look like a successful
        page load.
    """
    html_file = Path("templates/index.html")

    # EAFP: read directly instead of exists()-then-read, which could race
    # with the file disappearing between the two calls.
    try:
        return html_file.read_text(
            encoding="utf-8"
        )
    except FileNotFoundError:
        return HTMLResponse(
            "<h2>templates/index.html not found</h2>",
            status_code=404
        )

class RenameRequest(BaseModel):
    """Body of ``POST /rename_column``: map a raw column name to a readable one."""
    # Column name as it appears in the database schema.
    raw: str
    # Display name passed to ``update_column_name``.
    readable: str

class TableRenameRequest(BaseModel):
    """Body of ``POST /rename_table``: map a raw table name to a readable one."""
    # Table name as it appears in the database schema.
    raw: str
    # Display name passed to ``update_table_name``.
    readable: str


from sqlalchemy import text
from db import engine
from pydantic import BaseModel

# NOTE(review): dead code. A second ``class BIQuery`` further down this module
# rebinds the name before any endpoint that uses ``BIQuery`` is defined, so
# this minimal version is never used and can be deleted.
class BIQuery(BaseModel):
    """Superseded minimal BI request (entity + metric only)."""
    entity: str
    metric: str

class AskRequest(BaseModel):
    """Body of ``POST /ask``."""
    # Free-text question; forwarded to ``llm_interpret``.
    question: str

class BIQuery(BaseModel):
    """Structured BI request consumed by the ``/bi_query`` endpoint.

    Sent directly as a POST body, or built by ``/ask`` from LLM output.
    """
    # Semantic-catalog entity key, resolved via ``semantic_catalog.get_entity``.
    entity: str
    # "details" lists rows of the entity table; any other value produces an
    # aggregate report.
    report_type: str = "aggregate"
    # Aggregate function for aggregate reports; "count" and "sum" are handled.
    metric: str | None = None
    # Measure for "sum": looked up in the entity's ``measure_columns``,
    # otherwise used directly as a column of the entity table.
    field: str | None = None
    # Grouping dimension; only "district" is acted on by ``bi_query``.
    group_by: str | None = None
    # Sort direction hint for grouped reports (e.g. "desc").
    sort: str | None = None
    # Row limit requested for the report.
    limit: int | None = None
    # Filter name -> value. Names are looked up in the entity's catalog
    # ``filters``; "geo_type" also has dedicated handling in aggregate reports.
    filters: dict = {}

@app.post("/ask")
def ask(req: AskRequest):
    """Answer a natural-language question by routing it through the LLM.

    ``llm_interpret`` turns ``req.question`` into a structured request
    (entity, metric, filters, ...). That dict is normalised, converted into
    a :class:`BIQuery` and handed to ``bi_query``.

    Returns:
        The ``llm_interpret`` result unchanged when it reports failure; an
        error dict when the LLM output names no entity; otherwise whatever
        ``bi_query`` returns.
    """
    import re

    print("QUESTION:", req.question)

    llm_response = llm_interpret(req.question)

    print("LLM RESPONSE:", llm_response)

    if not llm_response["success"]:
        return llm_response

    data = llm_response["data"]

    print("DATA:", data)

    # The LLM may emit "" instead of null for unused optional fields.
    for key in ("group_by", "sort", "field"):
        if data.get(key) == "":
            data[key] = None

    # A "details" answer to a counting question is really an aggregate count.
    # Match "count"/"counts"/"counting" as a whole word. Previously any
    # question merely containing e.g. "account", "country" or "discount"
    # was rewritten.
    if (
        data.get("report_type") == "details"
        and re.search(r"\bcount(?:s|ing)?\b", req.question.lower())
    ):
        data.pop("report_type", None)
        data["metric"] = "count"

    # Without an entity there is nothing to query; report it instead of
    # failing with a KeyError (HTTP 500).
    entity = data.get("entity")
    if not entity:
        return {
            "success": False,
            "message": "Could not determine which entity the question is about"
        }

    # ``or`` defaults also cover explicit nulls from the LLM, which the
    # non-optional ``report_type`` / ``filters`` fields would reject.
    bi_req = BIQuery(
        entity=entity,
        report_type=data.get("report_type") or "aggregate",
        metric=data.get("metric"),
        field=data.get("field"),
        group_by=data.get("group_by"),
        sort=data.get("sort"),
        limit=data.get("limit"),
        filters=data.get("filters") or {}
    )

    print("BI REQ:", bi_req)
    return bi_query(bi_req)

@app.post("/bi_query")
def bi_query(req: BIQuery):
    """Execute a semantic BI request and return the generated SQL and rows.

    ``req.entity`` is resolved through ``semantic_catalog``. The entity info
    supplies the base ``table`` plus the optional ``filters``,
    ``base_filter``, ``measure_columns``, ``count_column`` and
    ``primary_key`` entries.

    * ``report_type == "details"``: every column of the entity table,
      joined to whichever tables the requested filters live on
      (``_bi_details_report``).
    * anything else: a ``count``/``sum`` aggregate built by
      ``generate_multi_join_sql`` (``_bi_aggregate_report``).

    Returns:
        ``{"success": False, "message": ...}`` for an unknown entity or an
        unsupported metric; otherwise ``{"success": True, "sql": ...,
        "data": [row dicts], ...}``.
    """
    entity_info = semantic_catalog.get_entity(
        req.entity
    )

    if not entity_info:
        return {
            "success": False,
            "message": f"Unknown entity: {req.entity}"
        }

    table = entity_info["table"]
    # ``or {}`` also covers a catalog entry that sets base_filter to None.
    base_filter = entity_info.get("base_filter") or {}

    if req.report_type == "details":
        return _bi_details_report(req, entity_info, table, base_filter)

    return _bi_aggregate_report(req, entity_info, table, base_filter)


def _filter_spec(filter_name, filter_value):
    """Return the ``{"op", "value"}`` condition spec for one request filter.

    ``status_level == "CARBON_DOCUMENTS_APPROVED"`` is widened to an ``IN``
    over the value and its trailing-underscore variant (presumably both
    spellings occur in the data). Every other filter is a plain equality.
    """
    if (
        filter_name == "status_level"
        and filter_value == "CARBON_DOCUMENTS_APPROVED"
    ):
        return {
            "op": "IN",
            "value": [
                "CARBON_DOCUMENTS_APPROVED",
                "CARBON_DOCUMENTS_APPROVED_"
            ]
        }

    return {
        "op": "=",
        "value": filter_value
    }


def _bi_details_report(req, entity_info, table, base_filter):
    """Run a ``SELECT <table>.*`` report filtered by ``req.filters``.

    Each request filter whose name appears in the entity's catalog
    ``filters`` becomes a WHERE condition on that filter's table/column.
    The tables involved are joined in via ``best_paths`` / ``merge_paths`` /
    ``find_join_key``. Unknown filter names are ignored.

    Filter and base-filter *values* are sent as bound parameters. They come
    from the HTTP request (or LLM output), and splicing them into the SQL
    text allowed SQL injection and broke on quotes. Table and column
    identifiers come from the semantic catalog and are still interpolated.

    NOTE(review): unlike the aggregate path, ``geo_type`` gets no special
    ``non_geofenced`` translation here; confirm whether it should.
    """
    catalog_filters = entity_info.get("filters") or {}
    params = {}
    conditions = []
    targets = []

    def bind(value):
        # Register *value* as a bound parameter and return its placeholder.
        name = f"p{len(params)}"
        params[name] = value
        return f":{name}"

    for filter_name, filter_value in req.filters.items():

        filter_info = catalog_filters.get(filter_name)

        if not filter_info:
            continue

        target_table = filter_info["table"]
        column_ref = f"{target_table}.{filter_info['column']}"

        if target_table not in targets:
            targets.append(target_table)

        spec = _filter_spec(filter_name, filter_value)

        # str() keeps the previous semantics, where every request value was
        # compared as a quoted string literal.
        if spec["op"] == "IN":
            placeholders = ", ".join(
                [bind(str(v)) for v in spec["value"]]
            )
            conditions.append(f"{column_ref} IN ({placeholders})")
        else:
            conditions.append(
                f"{column_ref} = {bind(str(spec['value']))}"
            )

    sql = f"""
            SELECT {table}.*
            FROM {table}
            """

    if targets:

        print("TARGETS:", targets)

        paths = best_paths(
            table,
            targets
        )

        print("DETAIL PATHS:", paths)
        edges = merge_paths(paths)
        print("DETAIL EDGES:", edges)

        for left_table, right_table in edges:

            left_col, right_col = find_join_key(
                left_table,
                right_table
            )

            sql += f"""
             JOIN {right_table}
                ON {left_table}.{left_col}
                = {right_table}.{right_col}
            """

    # Base filters were previously interpolated unquoted, which broke for
    # string values; binding handles any scalar type.
    for col, value in base_filter.items():
        conditions.append(f'{table}."{col}" = {bind(value)}')

    if conditions:
        sql += "\nWHERE " + " AND ".join(conditions)

    # Honour an explicit row limit (previously ignored for detail reports).
    if req.limit and req.limit > 0:
        sql += f"\nLIMIT {int(req.limit)}"

    print("\nDETAIL SQL:")
    print(sql)
    print("DETAIL PARAMS:", params)
    print("\n")

    with engine.connect() as conn:

        rows = conn.execute(
            text(sql),
            params
        )

        result = [
            dict(r._mapping)
            for r in rows
        ]

    return {
        "success": True,
        "entity": req.entity,
        "report_type": req.report_type,
        "filters": req.filters,
        "sql": sql,
        "params": params,
        "data": result
    }


def _bi_aggregate_report(req, entity_info, table, base_filter):
    """Build and run a ``count`` / ``sum`` aggregate report.

    Request filters and the entity's ``base_filter`` are translated into the
    ``where_filters`` structure passed to ``generate_multi_join_sql``. That
    function also receives every table in ``targets`` to join.

    NOTE(review): filter values travel inside ``where_filters``. Whether
    ``generate_multi_join_sql`` escapes or binds them is not visible here;
    confirm this, since the values originate from the request.
    """
    catalog_filters = entity_info.get("filters") or {}
    where_filters = {}
    targets = []
    group_by = None

    #
    # GROUP BY SUPPORT
    #
    # Only district grouping is wired up; it needs both agri_user_address
    # and agri_districts joined in.
    if req.group_by == "district":

        targets.extend([
            "agri_user_address",
            "agri_districts"
        ])

        group_by = [
            ("agri_districts", "district_name")
        ]

    #
    # DYNAMIC FILTERS (declared in the semantic catalog)
    #

    for filter_name, filter_value in req.filters.items():

        filter_info = catalog_filters.get(filter_name)

        if not filter_info:
            continue

        target_table = filter_info["table"]

        if target_table not in targets:
            targets.append(target_table)

        where_filters.setdefault(target_table, {})[filter_info["column"]] = (
            _filter_spec(filter_name, filter_value)
        )

    # Remove duplicates while keeping first-seen order.
    targets = list(dict.fromkeys(targets))

    #
    # BASE FILTERS FROM CATALOG (always on the entity table)
    #

    if base_filter:

        table_filters = where_filters.setdefault(table, {})

        for col, value in base_filter.items():
            table_filters[col] = {
                "op": "=",
                "value": value
            }

    #
    # METRIC SUPPORT
    #

    if req.metric == "count":

        metric_table = table
        metric_column = entity_info.get(
            "count_column",
            entity_info.get("primary_key")
        )
        agg_function = "count"

    elif req.metric == "sum":

        measure_info = (
            (entity_info.get("measure_columns") or {})
            .get(req.field)
        )

        if measure_info:

            metric_table = measure_info["table"]
            metric_column = measure_info["column"]

            if metric_table not in targets:
                targets.append(metric_table)

        else:

            metric_table = table
            metric_column = req.field

        agg_function = "sum"

    else:

        return {
            "success": False,
            "message": f"Unsupported metric: {req.metric}"
        }

    # Without a column (no count_column/primary_key in the catalog, or a sum
    # with no field) the generator would emit invalid SQL.
    if not metric_column:
        return {
            "success": False,
            "message": f"No column to {agg_function} for entity {req.entity}"
        }

    #
    # COLUMN SELECTION
    #

    column_filters = {
        metric_table: [metric_column]
    }

    if req.group_by == "district":
        column_filters["agri_districts"] = ["district_name"]

    #
    # GEO TYPE FILTER
    #

    if "geo_type" in req.filters:

        table_filters = where_filters.setdefault(table, {})
        geo_type = req.filters["geo_type"]

        if geo_type == "non_geofenced":
            table_filters["geo_type"] = {
                "op": "!=",
                "value": "geofenced"
            }

        elif geo_type == "geofenced":
            table_filters["geo_type"] = {
                "op": "=",
                "value": "geofenced"
            }

    #
    # AGGREGATIONS
    #

    aggregations = {
        metric_table: {
            metric_column: agg_function
        }
    }

    print("WHERE FILTERS:", where_filters)
    print("COLUMN FILTERS:", column_filters)
    print("AGGREGATIONS:", aggregations)

    #
    # GENERATE SQL
    #

    sql = generate_multi_join_sql(
        source=table,
        targets=targets,
        column_filters=column_filters,
        where_filters=where_filters,
        group_by=group_by,
        aggregations=aggregations
    )

    #
    # SORTING FOR GROUPED REPORTS
    #
    # The generator aliases aggregates as "<agg>_<column>" (e.g.
    # sum_land_size). The default sum sort used to be hard-coded to
    # sum_land_size, which was invalid SQL for any other measure.
    if req.group_by == "district":

        sort = (req.sort or "").lower()
        sort_alias = f"{agg_function}_{metric_column}"

        if sort in ("asc", "desc"):
            sql += f"\nORDER BY {sort_alias} {sort.upper()}"

        elif not sort and agg_function == "sum":
            # Default for sum reports: largest totals first.
            sql += f"\nORDER BY {sort_alias} DESC"

    #
    # TOP-N REPORTS
    #

    if req.limit and req.limit > 0:
        sql += f"\nLIMIT {int(req.limit)}"

    print("\nGENERATED SQL:")
    print(sql)
    print("\n")

    #
    # EXECUTE
    #

    with engine.connect() as conn:

        rows = conn.execute(
            text(sql)
        )

        result = [
            dict(r._mapping)
            for r in rows
        ]

    return {
        "success": True,
        "entity": req.entity,
        "metric": req.metric,
        "group_by": req.group_by,
        "filters": req.filters,
        "sql": sql,
        "data": result
    }


   

@app.get("/bi_test")
def bi_test():
    """Smoke-test generator + database: count ``agri_users.user_id``.

    Returns the generated SQL together with the rows it produced.
    """
    spec = {
        "source": "agri_users",
        "targets": [],
        "column_filters": {"agri_users": ["user_id"]},
        "aggregations": {"agri_users": {"user_id": "count"}},
    }
    generated = generate_multi_join_sql(**spec)

    with engine.connect() as connection:
        cursor = connection.execute(text(generated))
        records = [dict(record._mapping) for record in cursor]

    return {"sql": generated, "data": records}

@app.get("/test_farmer_count")
def test_farmer_count():
    """Return, without executing it, the SQL counting ``agri_users.user_id``."""
    query = generate_multi_join_sql(
        source="agri_users",
        targets=[],
        column_filters=dict(agri_users=["user_id"]),
        aggregations=dict(agri_users=dict(user_id="count")),
    )
    return dict(sql=query)

@app.get("/tables")
def tables():
    """Return the table listing held by ``schema_cache``."""
    cached_tables = get_cached_tables()
    return cached_tables

@app.get("/columns")
def columns():
    """Return the per-table column listing held by ``schema_cache``."""
    cached_columns = get_cached_columns()
    return cached_columns

@app.get("/graph")
def graph():
    """Return the graph from ``build_graph`` as plain node and edge lists."""
    join_graph = build_graph()
    node_list = list(join_graph.nodes)
    edge_list = list(join_graph.edges)
    return {"nodes": node_list, "edges": edge_list}

@app.get("/predict_join")
def predict(t1: str, t2: str):
    """Return ``predict_joinability``'s score for tables *t1* and *t2*."""
    return {"joinability_score": predict_joinability(t1, t2)}

@app.get("/best_path")
def path(source: str, target: str):
    """Return the join path that ``best_path`` picks from *source* to *target*."""
    chosen = best_path(source, target)
    return {"best_path": chosen}

@app.get("/generate_multi_join_sql_filtered")
def multi_join_sql_filtered(
    source: str,
    targets: str,
    columns: Optional[str] = Query(None, description="Format: table1.col1,col2;table2.col3,col4"),
    where: Optional[str] = Query(None, description="JSON: {\"table\": {\"col\": {\"op\": \"=\", \"value\": \"val\"}}}"),
    order: Optional[str] = Query(None, description="JSON: [[\"table\", \"col\", \"asc\"]]"),
    limit: Optional[int] = Query(None),
    offset: Optional[int] = Query(None)
):
    """Generate (but do not run) a multi-join SELECT from query parameters.

    Args:
        source: Table the joins start from.
        targets: Comma-separated tables to join to ``source``. Whitespace
            and empty entries are ignored.
        columns: ``table.col1,col2;table2.col3`` column selection. A table
            without a ``.cols`` part selects ``*``; empty groups are skipped.
        where: JSON where-filter spec (see the parameter description).
        order: JSON list of ``[table, column, direction]`` triples.
        limit, offset: Passed straight to the generator.

    Returns:
        ``{"sql": <generated SQL>}``.

    Raises:
        HTTPException: 400 when ``where`` or ``order`` is not valid JSON.
            Previously this was an unhandled ``JSONDecodeError`` (HTTP 500).
    """
    from fastapi import HTTPException

    def parse_json(raw, param, default):
        # Decode an optional JSON query parameter; bad input is a client error.
        if not raw:
            return default
        try:
            return json.loads(raw)
        except json.JSONDecodeError as exc:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid JSON in '{param}' parameter: {exc}"
            ) from exc

    # "a, b," -> ["a", "b"]; previously this yielded " b" and "" entries.
    target_list = [t.strip() for t in targets.split(",") if t.strip()]

    column_filters = {}
    if columns:
        for group in columns.split(";"):
            # Split on the first "." only; the rest is the column list.
            table, has_cols, col_part = group.strip().partition(".")
            table = table.strip()
            if not table:
                continue  # empty group, e.g. from a trailing ";"
            cols = (
                [c.strip() for c in col_part.split(",") if c.strip()]
                if has_cols else []
            )
            column_filters[table] = cols or ["*"]

    where_filters = parse_json(where, "where", {})
    order_by = parse_json(order, "order", [])

    # NOTE(review): positional call. Confirm that generate_multi_join_sql's
    # 5th-7th parameters really are (order_by, limit, offset); elsewhere in
    # this module it is called with group_by= / aggregations= keywords.
    sql = generate_multi_join_sql(
        source,
        target_list,
        column_filters,
        where_filters,
        order_by,
        limit,
        offset
    )
    return {"sql": sql}

@app.get("/suggest_column_name")
def suggest(raw: str):
    """Suggest a readable name for the raw column name *raw*.

    ``confidence`` is the score from ``suggest_column_name``, rounded to
    three decimals.
    """
    readable, score = suggest_column_name(raw)
    rounded = round(score, 3)
    return {"suggested_name": readable, "confidence": rounded}

@app.get("/columns_with_suggestions")
def columns_with_suggestions():
    """Return every cached column paired with its suggested readable name.

    Shape: ``{table: [{"original": col, "suggested": name}, ...]}``.
    """
    return {
        table_name: [
            {"original": column, "suggested": suggest_column_name(column)[0]}
            for column in column_names
        ]
        for table_name, column_names in get_cached_columns().items()
    }

@app.post("/rename_column")
def rename_column(req: RenameRequest):
    """Record *req.readable* as the display name of column *req.raw*."""
    old_name, new_name = req.raw, req.readable
    update_column_name(old_name, new_name)
    return {"message": f"Renamed '{old_name}' to '{new_name}'"}

@app.post("/rename_table")
def rename_table_api(req: TableRenameRequest):
    """Record *req.readable* as the display name of table *req.raw*."""
    old_name, new_name = req.raw, req.readable
    update_table_name(old_name, new_name)
    return {"message": f"Renamed table '{old_name}' to '{new_name}'"}

@app.get("/generate_report")
def generate_report(
    metric: str,
    dimension: str,
    limit: int = 20
):
    """Delegate to ``generate_report_sql`` for *metric* broken down by *dimension*."""
    report_args = {"metric": metric, "dimension": dimension, "limit": limit}
    return generate_report_sql(**report_args)

@app.post("/refresh_schema")
def refresh_schema():
    """Drop and immediately re-populate the cached table and column listings."""
    loaders = (get_cached_tables, get_cached_columns)

    for loader in loaders:
        loader.cache_clear()

    # Warm the caches again so the next request does not pay the reload cost.
    for loader in loaders:
        loader()

    return {"message": "Schema cache refreshed"}

@app.post("/understand_query_rule")
def understand_query_rule(query: str):
    """Interpret *query* with ``query_interpreter.interpret``."""
    interpretation = interpret(query)
    return interpretation


class QueryRequest(BaseModel):
    """Body shared by the LLM-backed query endpoints."""
    # Natural-language query forwarded to ``llm_interpret``.
    query: str    

@app.post("/understand_query")
def understand_query(req: QueryRequest):
    """Return ``llm_interpret``'s raw interpretation of ``req.query``."""
    interpretation = llm_interpret(req.query)
    return interpretation


@app.post("/generate_report_definition")
def generate_report_definition(req: QueryRequest):
    """Turn ``req.query`` into a report definition via the LLM.

    Returns the ``llm_interpret`` result unchanged when it fails, otherwise
    ``json_to_report`` applied to its ``data``.
    """
    interpretation = llm_interpret(req.query)

    if interpretation["success"]:
        return json_to_report(interpretation["data"])

    return interpretation

@app.post("/generate_sql_from_query")
def generate_sql_from_query(req: QueryRequest):
    """Translate ``req.query`` into SQL (not executed).

    Pipeline: ``llm_interpret`` -> ``json_to_report`` ->
    ``generate_multi_join_sql``. A failed interpretation is returned as-is.
    """
    interpretation = llm_interpret(req.query)
    if not interpretation["success"]:
        return interpretation

    definition = json_to_report(interpretation["data"])

    generator_args = {
        key: definition[key]
        for key in ("source", "targets", "column_filters", "where_filters")
    }
    generated = generate_multi_join_sql(**generator_args)

    return {
        "success": True,
        "sql": generated,
        "report_definition": definition,
    }
@app.get("/debug_path")
def debug_path():
    """Debug: ``best_paths`` result from ``agri_users`` to ``agri_districts``."""
    district_route = best_paths("agri_users", ["agri_districts"])
    return {"district_path": district_route}


@app.get("/debug_paths")
def debug_paths():
    """Debug: ``best_paths`` from ``agri_users`` to the district/crop tables."""
    wanted = ["agri_districts", "agri_crop_history", "agri_crop_user_rels"]
    found = best_paths("agri_users", wanted)
    return {"paths": found}


@app.get("/debug_columns")
def debug_columns():
    """Debug: cached column lists of the core agri tables.

    A table absent from the cache maps to ``None``.
    """
    cached = get_cached_columns()
    core_tables = (
        "agri_users",
        "agri_districts",
        "agri_user_address",
        "agri_land",
    )
    return {name: cached.get(name) for name in core_tables}

@app.get("/debug_path_address")
def debug_path_address():
    """Debug: ``best_paths`` from ``agri_users`` to ``agri_user_address``."""
    route = best_paths("agri_users", ["agri_user_address"])
    return {"path": route}


@app.get("/debug_shared/{t1}/{t2}")
def debug_shared(t1: str, t2: str):
    """Debug: return the column names that tables *t1* and *t2* share.

    Returns:
        ``{"shared": [column, ...]}``, sorted so the output is stable across
        runs (plain set iteration order is not).

    Raises:
        HTTPException: 404 when either table is not in the schema cache.
            Previously this surfaced as a ``KeyError`` / HTTP 500.
    """
    from fastapi import HTTPException

    cols = get_cached_columns()

    missing = [t for t in (t1, t2) if t not in cols]
    if missing:
        raise HTTPException(
            status_code=404,
            detail=f"Unknown table(s): {', '.join(missing)}"
        )

    return {
        "shared": sorted(
            set(cols[t1]) &
            set(cols[t2])
        )
    }




@app.post("/execute_query")
def execute_query(req: QueryRequest):
    """Interpret ``req.query`` with the LLM, build SQL, and run it.

    Only the first 20 rows are fetched, and ``count`` reports how many of
    those were returned. Database errors are reported as
    ``{"success": False, "sql": ..., "error": ...}`` rather than raised.
    """
    interpretation = llm_interpret(req.query)
    if not interpretation["success"]:
        return interpretation

    definition = json_to_report(interpretation["data"])

    sql = generate_multi_join_sql(
        source=definition["source"],
        targets=definition["targets"],
        column_filters=definition["column_filters"],
        where_filters=definition["where_filters"]
    )

    try:
        with engine.connect() as connection:
            cursor = connection.execute(text(sql))
            preview = [dict(record._mapping) for record in cursor.fetchmany(20)]
    except Exception as exc:
        return {"success": False, "sql": sql, "error": str(exc)}

    return {
        "success": True,
        "sql": sql,
        "count": len(preview),
        "data": preview,
    }
    


@app.get("/sample_user_address")
def sample_user_address():
    """Debug: up to 20 ``(user_id, district)`` rows from ``agri_user_address``."""
    sample_sql = text("""
                SELECT
                    user_id,
                    district
                FROM agri_user_address
                LIMIT 20
            """)

    with engine.connect() as connection:
        records = connection.execute(sample_sql)
        return [dict(record._mapping) for record in records]

@app.get("/debug_path_district")
def debug_path_district():
    """Debug: ``best_paths`` from ``agri_user_address`` to ``agri_districts``."""
    route = best_paths("agri_user_address", ["agri_districts"])
    return {"path": route}

@app.get("/debug_shared_address_district")
def debug_shared_address_district():
    """Debug: cached columns of ``agri_user_address`` and ``agri_districts``.

    Raises ``KeyError`` if either table is missing from the cache.
    """
    cached = get_cached_columns()
    wanted = ("agri_user_address", "agri_districts")
    return {name: cached[name] for name in wanted}


@app.get("/debug_full_path")
def debug_full_path():
    """Debug: ``best_paths`` from ``agri_users`` to address, district and crop tables."""
    destinations = [
        "agri_user_address",
        "agri_districts",
        "agri_crop_history",
        "agri_crop_user_rels",
    ]
    found = best_paths("agri_users", destinations)
    return {"paths": found}

@app.get("/debug_columns/{table}")
def debug_columns(table: str):
    """Debug: cached column names of *table* (empty list when unknown)."""
    known = get_cached_columns()
    return {"table": table, "columns": known.get(table, [])}

@app.get("/search_column/{column_name}")
def search_column(column_name: str):
    """Find cached columns whose name contains *column_name*, ignoring case.

    Returns ``{table: [matching columns]}``, omitting tables without a match.
    """
    needle = column_name.lower()
    hits = {}

    for table, table_columns in get_cached_columns().items():
        matching = [name for name in table_columns if needle in name.lower()]
        if matching:
            hits[table] = matching

    return hits


from sqlalchemy import text
from db import engine

@app.get("/run_sql")
def run_sql(sql: str):
    """Run an ad-hoc read-only query from the query string and return its rows.

    The SQL comes verbatim from any HTTP caller. It is therefore restricted
    to a single ``SELECT`` / ``WITH`` statement, and the transaction is
    always rolled back, so nothing the statement changes is committed
    (barring engines whose DDL auto-commits).

    NOTE(review): this is still a debugging aid that exposes the whole
    database to any caller; it should not be reachable in production.

    Returns:
        ``{"success": True, "count": n, "data": [row dicts]}``, or
        ``{"success": False, "error": message}`` for a rejected statement or
        a database error.
    """
    import re

    statement = sql.strip().rstrip(";").rstrip()

    # One statement only, and it must start as a read query.
    if (
        not re.match(r"(?i)(select|with)\b", statement)
        or ";" in statement
    ):
        return {
            "success": False,
            "error": "Only a single read-only SELECT or WITH statement is allowed"
        }

    try:

        with engine.connect() as conn:

            trans = conn.begin()
            try:
                result = conn.execute(
                    text(statement)
                )

                rows = [
                    dict(r._mapping)
                    for r in result.fetchall()
                ]
            finally:
                # Never persist anything an ad-hoc query might have changed.
                trans.rollback()

        return {
            "success": True,
            "count": len(rows),
            "data": rows
        }

    except Exception as e:

        return {
            "success": False,
            "error": str(e)
        }

@app.get("/search_tables/{keyword}")
def search_tables(keyword: str):
    """Return cached table names containing *keyword*, ignoring case."""
    needle = keyword.lower()
    matches = []
    for table_name in get_cached_tables():
        if needle in table_name.lower():
            matches.append(table_name)
    return matches









# @app.post("/bi_query")
# def bi_query(req: BIQuery):

#     entity_info = semantic_catalog.get_entity(
#         req.entity
#     )

#     if not entity_info:
#         return {
#             "success": False,
#             "message": f"Unknown entity: {req.entity}"
#         }

#     table = entity_info["table"]

#      #
# # DETAIL REPORTS
# #

# if req.report_type == "details":

#     sql = f"""
# SELECT *
# FROM {table}
# """

#     conditions = []

#     if (
#         req.entity == "plot"
#         and req.filters.get("geo_type") == "non_geofenced"
#     ):
#         conditions.append(
#             "geo_type != 'geofenced'"
#         )

#     elif (
#         req.entity == "plot"
#         and req.filters.get("geo_type") == "geofenced"
#     ):
#         conditions.append(
#             "geo_type = 'geofenced'"
#         )

#     if conditions:

#         sql += "\nWHERE " + " AND ".join(
#             conditions
#         )

#     sql += "\nLIMIT 100"

#     with engine.connect() as conn:

#         rows = conn.execute(
#             text(sql)
#         )

#         result = [
#             dict(r._mapping)
#             for r in rows
#         ]

#     return {
#         "success": True,
#         "report_type": "details",
#         "entity": req.entity,
#         "sql": sql,
#         "data": result
#     }

#     base_filter = entity_info.get(
#         "base_filter",
#         {}
#     )

#     where_filters = {}
#     targets = []
#     group_by = None

#     #
#     # GROUP BY SUPPORT
#     #

#     if req.group_by == "district":

#         targets.extend([
#             "agri_user_address",
#             "agri_districts"
#         ])

#         group_by = [
#             ("agri_districts", "district_name")
#         ]

#     #
#     # DISTRICT FILTER
#     #

#     if "district" in req.filters:

#         targets.extend([
#             "agri_user_address",
#             "agri_districts"
#         ])

#         where_filters["agri_districts"] = {
#             "district_name": {
#                 "op": "=",
#                 "value": req.filters["district"]
#             }
#         }

#     #
#     # REMOVE DUPLICATES
#     #

#     targets = list(
#         dict.fromkeys(targets)
#     )

#     #
#     # BASE FILTERS FROM CATALOG
#     #

#     if base_filter:

#         where_filters.setdefault(
#             table,
#             {}
#         )

#         for col, value in base_filter.items():

#             where_filters[table][col] = {
#                 "op": "=",
#                 "value": value
#             }

#     #
#     # METRIC SUPPORT
#     #

#     if req.metric == "count":

#         metric_column = entity_info.get(
#             "count_column",
#             "user_id"
#         )

#         agg_function = "count"

#     elif req.metric == "sum":

#         metric_column = req.field

#         agg_function = "sum"

#     else:

#         return {
#             "success": False,
#             "message": f"Unsupported metric: {req.metric}"
#         }

#     #
#     # COLUMN SELECTION
#     #

#     column_filters = {
#         table: [metric_column]
#     }

#     if req.group_by == "district":

#         column_filters[
#             "agri_districts"
#         ] = [
#             "district_name"
#         ]

#     if "geo_type" in req.filters:

#         where_filters.setdefault(
#             table,
#             {}
#         )

#         if req.filters["geo_type"] == "non_geofenced":

#             where_filters[table]["geo_type"] = {
#                 "op": "!=",
#                 "value": "geofenced"
#             }

#         elif req.filters["geo_type"] == "geofenced":

#             where_filters[table]["geo_type"] = {
#                 "op": "=",
#                 "value": "geofenced"
#             }
#     #
#     # AGGREGATIONS
#     #

#     aggregations = {
#         table: {
#             metric_column: agg_function
#         }
#     }

#     #
#     # GENERATE SQL
#     #

#     sql = generate_multi_join_sql(
#         source=table,
#         targets=targets,
#         column_filters=column_filters,
#         where_filters=where_filters,
#         group_by=group_by,
#         aggregations=aggregations
#     )
#     if (
#     req.group_by == "district"
#     and req.metric == "sum"
#   ):
#      sql += "\nORDER BY sum_land_size DESC"
#     #
#     # TOP-N REPORTS
#     #

#     if (
#         req.group_by == "district"
#         and req.sort == "desc"
#     ):

#         if req.metric == "count":

#             sql += "\nORDER BY count_" + metric_column + " DESC"

#         elif req.metric == "sum":

#             sql += "\nORDER BY sum_" + metric_column + " DESC"

#     if req.limit:

#         sql += f"\nLIMIT {req.limit}"

#     print("\nGENERATED SQL:")
#     print(sql)
#     print("\n")

#     #
#     # EXECUTE
#     #

#     with engine.connect() as conn:

#         rows = conn.execute(
#             text(sql)
#         )

#         result = [
#             dict(r._mapping)
#             for r in rows
#         ]

#     return {
#         "success": True,
#         "entity": req.entity,
#         "metric": req.metric,
#         "group_by": req.group_by,
#         "filters": req.filters,
#         "sql": sql,
#         "data": result
#     }