import asyncio
from typing import List, Dict, Any
from sqlalchemy import func, select, cast, String, text, Integer, case
from sqlalchemy.orm import selectinload
from sqlalchemy.ext.asyncio import AsyncSession
from database.models import AILog, User
from utils.date_utils import get_date_range
from datetime import date


class AnalyticsService:
    """
    Service layer for handling analytics queries over ``AILog`` rows.

    Every query method takes an ``AsyncSession`` and a ``start_date`` /
    ``end_date`` pair. It converts the dates to datetimes with
    ``get_date_range`` and then filters ``AILog.created_at`` with SQL
    ``BETWEEN``, which is inclusive at both ends. The service keeps no
    instance state.

    NOTE(review): ``get_top_json_items`` uses raw ``JSON_TABLE`` SQL, so the
    backend appears to be MySQL 8 / MariaDB. Queries here avoid constructs
    that backend lacks (e.g. aggregate ``FILTER``).
    """

    # --- Overview ---
    async def get_overview_kpis(
        self, db: AsyncSession, start_date: date, end_date: date
    ) -> Dict[str, Any]:
        """
        Gather key performance indicators for a date range in one query.

        Args:
            db: Async database session used to run the query.
            start_date: First day of the range.
            end_date: Last day of the range.

        Returns:
            A dict with ``total_searches`` (number of log rows),
            ``unique_users`` (count of distinct non-NULL ``user_id`` values)
            and ``success_rate_percent`` (share of rows whose ``success`` is
            true, on a 0-100 scale, unrounded). All three are zero when the
            range contains no searches.
        """
        start_datetime, end_datetime = get_date_range(start_date, end_date)

        # One round trip computes all aggregates. COUNT(CASE ...) counts only
        # rows where the CASE yields a non-NULL id, i.e. successful searches.
        stmt = select(
            func.count(AILog.id).label("total_searches"),
            func.count(func.distinct(AILog.user_id)).label("unique_users"),
            func.count(
                case((AILog.success == True, AILog.id))  # noqa: E712 - SQL expr
            ).label("successful_searches"),
        ).where(AILog.created_at.between(start_datetime, end_datetime))

        result = await db.execute(stmt)
        row = result.one_or_none()

        # Guard against division by zero for an empty range.
        if not row or row.total_searches == 0:
            return {
                "total_searches": 0,
                "unique_users": 0,
                "success_rate_percent": 0.0,
            }

        success_rate = (row.successful_searches / row.total_searches) * 100

        return {
            "total_searches": row.total_searches,
            "unique_users": row.unique_users,
            "success_rate_percent": success_rate,
        }

    async def get_gender_distribution(
        self, db: AsyncSession, start_date: date, end_date: date
    ) -> List[Dict[str, Any]]:
        """
        Count searches per user gender within a date range.

        Only logs that join to a ``User`` (inner join on ``user_id``) and whose
        user has a non-NULL ``gender`` are counted.

        Returns:
            A list of ``{"gender": ..., "count": int}`` dicts, most searches
            first.
        """
        start_datetime, end_datetime = get_date_range(start_date, end_date)

        stmt = (
            select(
                User.gender,
                func.count(AILog.id).label("search_count"),
            )
            .join(User, AILog.user_id == User.id)
            .where(
                AILog.created_at.between(start_datetime, end_datetime),
                User.gender.isnot(None),
            )
            .group_by(User.gender)
            .order_by(func.count(AILog.id).desc())
        )

        result = await db.execute(stmt)
        return [
            {"gender": row.gender, "count": row.search_count} for row in result
        ]

    async def get_total_searches(
        self, db: AsyncSession, start_date: date, end_date: date
    ) -> int:
        """Return the number of searches within a date range (0 if none)."""
        start_datetime, end_datetime = get_date_range(start_date, end_date)
        stmt = select(func.count(AILog.id)).where(
            AILog.created_at.between(start_datetime, end_datetime)
        )
        result = await db.execute(stmt)
        return result.scalar_one_or_none() or 0

    async def get_unique_users(
        self, db: AsyncSession, start_date: date, end_date: date
    ) -> int:
        """
        Return the number of distinct non-NULL ``user_id`` values that
        searched within a date range (0 if none).
        """
        start_datetime, end_datetime = get_date_range(start_date, end_date)
        stmt = select(func.count(func.distinct(AILog.user_id))).where(
            AILog.created_at.between(start_datetime, end_datetime)
        )
        result = await db.execute(stmt)
        return result.scalar_one_or_none() or 0

    async def get_success_rate(
        self, db: AsyncSession, start_date: date, end_date: date
    ) -> float:
        """
        Return the search success rate within a date range.

        Returns:
            The percentage (0-100, unrounded) of searches whose ``success`` is
            true, or ``0.0`` when there were no searches.
        """
        start_datetime, end_datetime = get_date_range(start_date, end_date)
        stmt = select(
            # COUNT(CASE ...) rather than COUNT(...) FILTER (WHERE ...): the
            # aggregate FILTER clause is not supported by MySQL/MariaDB, so the
            # previous form failed there. This mirrors get_overview_kpis.
            func.count(case((AILog.success == True, AILog.id))),  # noqa: E712
            func.count(AILog.id),
        ).where(AILog.created_at.between(start_datetime, end_datetime))

        result = await db.execute(stmt)
        successful_searches, total_searches = result.one_or_none() or (0, 0)

        if total_searches == 0:
            return 0.0
        return (successful_searches / total_searches) * 100

    async def get_searches_trend(
        self, db: AsyncSession, start_date: date, end_date: date
    ) -> Dict[date, int]:
        """
        Count searches per calendar day within a date range.

        Returns:
            A dict mapping the SQL ``DATE(created_at)`` value to that day's
            search count, in ascending day order. Days with no searches are
            absent (they produce no GROUP BY row).
        """
        start_datetime, end_datetime = get_date_range(start_date, end_date)
        stmt = (
            select(
                func.date(AILog.created_at).label("search_date"),
                func.count(AILog.id).label("search_count"),
            )
            .where(AILog.created_at.between(start_datetime, end_datetime))
            # Group/order by the select label so the DATE() expression is
            # rendered once and referenced by alias.
            .group_by("search_date")
            .order_by("search_date")
        )
        result = await db.execute(stmt)
        return {row.search_date: row.search_count for row in result}

    # --- Food Insights ---
    @staticmethod
    def _validate_sql_identifier(name: str) -> None:
        """
        Raise ``ValueError`` unless *name* is a plain SQL identifier.

        Accepts ``column`` or dotted ``table.column`` forms where every part is
        ASCII and matches ``[A-Za-z_][A-Za-z0-9_]*``. That leaves no room for
        quotes, whitespace, comments, parentheses or operators, so the name is
        safe to splice into raw SQL text.
        """
        if not isinstance(name, str) or not all(
            part.isascii() and part.isidentifier() for part in name.split(".")
        ):
            raise ValueError(f"Invalid SQL identifier: {name!r}")

    async def get_top_json_items(
        self,
        db: AsyncSession,
        start_date: date,
        end_date: date,
        json_column: str,
        limit: int,
    ) -> List[Dict[str, Any]]:
        """
        Return the most frequent elements of a JSON array column.

        Each log row's array in *json_column* is expanded with ``JSON_TABLE``,
        and elements are extracted as ``VARCHAR(255)``, counted and ranked.

        Args:
            db: Async database session used to run the query.
            start_date: First day of the range.
            end_date: Last day of the range.
            json_column: Name of a JSON-array column of ``ai_logs`` (optionally
                qualified as ``table.column``). It is spliced into the SQL
                text, so it must be a plain identifier.
            limit: Maximum number of items to return.

        Returns:
            A list of ``{"item": str, "count": int}`` dicts, most frequent
            first.

        Raises:
            ValueError: If *json_column* is not a plain SQL identifier.
        """
        # Identifiers cannot be bound as parameters, so json_column ends up in
        # the SQL text verbatim. Validate it first to prevent SQL injection.
        self._validate_sql_identifier(json_column)
        start_datetime, end_datetime = get_date_range(start_date, end_date)

        stmt = text(
            f"""
            SELECT item, COUNT(item) AS count
            FROM ai_logs,
                 JSON_TABLE(
                    {json_column},
                    '$[*]' COLUMNS(item VARCHAR(255) PATH '$')
                 ) AS jt
            WHERE ai_logs.created_at BETWEEN :start_datetime AND :end_datetime
            GROUP BY item
            ORDER BY count DESC
            LIMIT :limit
            """
        )

        result = await db.execute(
            stmt,
            {
                "start_datetime": start_datetime,
                "end_datetime": end_datetime,
                "limit": limit,
            },
        )
        return [{"item": row.item, "count": row.count} for row in result]

    async def get_zero_result_queries(
        self, db: AsyncSession, start_date: date, end_date: date, limit: int
    ) -> List[Dict[str, Any]]:
        """
        Return the most common queries that found zero results, with a
        per-country breakdown for each.

        Args:
            db: Async database session used to run the queries.
            start_date: First day of the range.
            end_date: Last day of the range.
            limit: Maximum number of distinct queries to return.

        Returns:
            A list (most attempts first) of dicts with ``query`` (the raw query
            text), ``attempts`` (zero-result searches for it) and
            ``country_distribution``. The distribution is a list of
            ``{"country", "count", "percentage"}`` dicts, largest first.
            ``percentage`` is relative to *all* attempts, including those with
            no country recorded, so the percentages may sum to less than 100.
        """
        start_datetime, end_datetime = get_date_range(start_date, end_date)

        # Step 1: find the top N failing queries.
        top_queries_stmt = (
            select(AILog.raw_query, func.count(AILog.id).label("attempts"))
            .where(
                AILog.results_found == 0,
                AILog.created_at.between(start_datetime, end_datetime),
            )
            .group_by(AILog.raw_query)
            .order_by(func.count(AILog.id).desc())
            .limit(limit)
        )
        top_queries_result = await db.execute(top_queries_stmt)
        top_queries = top_queries_result.all()

        if not top_queries:
            return []

        # Step 2: country distribution for each top query.
        # NOTE(review): this issues one extra query per top query (N+1). It is
        # fine for small limits. Matching stays in SQL (`==` renders IS NULL
        # for a NULL raw_query and honours the column collation), which is why
        # it is not batched into a single IN (...) query.
        final_results = []
        for query in top_queries:
            country_dist_stmt = (
                select(AILog.country, func.count(AILog.id).label("count"))
                .where(
                    AILog.raw_query == query.raw_query,
                    AILog.results_found == 0,
                    AILog.created_at.between(start_datetime, end_datetime),
                    AILog.country.isnot(None),
                )
                .group_by(AILog.country)
                .order_by(func.count(AILog.id).desc())
            )
            country_dist_result = await db.execute(country_dist_stmt)

            country_distribution = [
                {
                    "country": row.country,
                    "count": row.count,
                    # attempts >= 1 here because the query came from a
                    # GROUP BY.
                    "percentage": round((row.count / query.attempts) * 100, 2),
                }
                for row in country_dist_result.all()
            ]

            final_results.append(
                {
                    "query": query.raw_query,
                    "attempts": query.attempts,
                    "country_distribution": country_distribution,
                }
            )

        return final_results

    # --- Restaurant Insights ---
    async def get_most_common_restaurants(
        self,
        db: AsyncSession,
        start_date: date,
        end_date: date,
        column_name: str,
        limit: int,
    ) -> List[Dict[str, Any]]:
        """
        Return the most common non-NULL values of a restaurant column.

        Args:
            db: Async database session used to run the query.
            start_date: First day of the range.
            end_date: Last day of the range.
            column_name: Name of a mapped column attribute on ``AILog``. It is
                resolved with ``getattr``, so an unknown name raises
                ``AttributeError``.
                NOTE(review): callers should pass a fixed name, not user input.
            limit: Maximum number of restaurants to return.

        Returns:
            A list of ``{"restaurant": ..., "count": int}`` dicts, most common
            first.
        """
        start_datetime, end_datetime = get_date_range(start_date, end_date)
        column = getattr(AILog, column_name)

        stmt = (
            select(column.label("restaurant"), func.count(AILog.id).label("count"))
            .where(
                column.isnot(None),
                AILog.created_at.between(start_datetime, end_datetime),
            )
            .group_by(column)
            .order_by(func.count(AILog.id).desc())
            .limit(limit)
        )
        result = await db.execute(stmt)
        return [
            {"restaurant": row.restaurant, "count": row.count} for row in result
        ]

    # --- Geographic Insights ---
    async def get_searches_by_country(
        self, db: AsyncSession, start_date: date, end_date: date, limit: int
    ) -> List[Dict[str, Any]]:
        """
        Return search counts per non-NULL country, largest first, capped at
        *limit* countries.
        """
        start_datetime, end_datetime = get_date_range(start_date, end_date)
        stmt = (
            select(AILog.country, func.count(AILog.id).label("count"))
            .where(
                AILog.country.isnot(None),
                AILog.created_at.between(start_datetime, end_datetime),
            )
            .group_by(AILog.country)
            .order_by(func.count(AILog.id).desc())
            .limit(limit)
        )
        result = await db.execute(stmt)
        return [{"country": row.country, "count": row.count} for row in result]

    async def get_success_rate_by_country(
        self, db: AsyncSession, start_date: date, end_date: date, limit: int
    ) -> List[Dict[str, Any]]:
        """
        Return the search success rate for each country.

        Countries are ordered by total searches (descending) and capped at
        *limit*.

        Returns:
            A list of ``{"country", "total_searches", "success_rate"}`` dicts.
            ``success_rate`` is a float percentage rounded to two decimals.
            It is ``0.0`` when no row in the country has a non-NULL
            ``success`` (the SQL SUM is then NULL).
        """
        start_datetime, end_datetime = get_date_range(start_date, end_date)

        # Inner query: per-country totals and successes (success cast to 0/1).
        subquery = (
            select(
                AILog.country,
                func.count(AILog.id).label("total_searches"),
                func.sum(cast(AILog.success, Integer)).label("successful_searches"),
            )
            .where(
                AILog.country.isnot(None),
                AILog.created_at.between(start_datetime, end_datetime),
            )
            .group_by(AILog.country)
            .subquery()
        )

        # Outer query: derive the percentage in SQL. total_searches >= 1 for
        # every group, so the division is safe. Multiplying by 100.0 forces
        # non-integer division.
        stmt = (
            select(
                subquery.c.country,
                subquery.c.total_searches,
                subquery.c.successful_searches,
                (
                    (subquery.c.successful_searches * 100.0)
                    / subquery.c.total_searches
                ).label("success_rate"),
            )
            .order_by(subquery.c.total_searches.desc())
            .limit(limit)
        )

        result = await db.execute(stmt)
        return [
            {
                "country": row.country,
                "total_searches": row.total_searches,
                # The driver may return a Decimal for the division. Normalise
                # to float so the field always has one numeric type.
                "success_rate": (
                    round(float(row.success_rate), 2)
                    if row.success_rate is not None
                    else 0.0
                ),
            }
            for row in result
        ]

def get_analytics_service() -> AnalyticsService:
    """
    Provide an ``AnalyticsService`` for dependency injection.

    A brand-new instance is constructed on every call. Nothing is cached
    or shared between callers.
    """
    service = AnalyticsService()
    return service
