from datetime import datetime, timedelta, timezone

from sqlalchemy import func

from app_server.extensions import db
from app_server.models import Expense, Order, OrderItem, PosSyncLog

# Ordered (key, label) pairs offered in the report-type selector. The key is
# what the ``type`` query parameter must match; the label is shown to users.
ADMIN_REPORT_TYPES = tuple(
    {
        "sales_by_date": "Sales by Date",
        "sales_by_product": "Sales by Product",
        "expenses": "Expense Reports",
        "daily_totals": "Daily totals (Web vs POS)",
        "pos_sync_sales": "POS synced sales",
    }.items()
)


def parse_report_date(value, fallback):
    """Parse the leading ``YYYY-MM-DD`` part of *value* into a ``date``.

    *value* is stringified and only its first ten characters are parsed, so
    ISO timestamps such as ``"2024-03-05T10:00"`` are accepted. *fallback* is
    returned unchanged when *value* is falsy or does not start with a valid
    ``%Y-%m-%d`` date.
    """
    if value:
        head = str(value)[:10]
        try:
            parsed = datetime.strptime(head, "%Y-%m-%d")
        except ValueError:
            parsed = None
        if parsed is not None:
            return parsed.date()
    return fallback


def _orders_in_range(range_start, range_end_excl, *, source=None):
    """Return non-refunded orders dated in ``[range_start, range_end_excl)``.

    When *source* is truthy, only orders whose ``Order.source`` equals it are
    kept. Results are ordered by ``Order.date`` descending (newest first).
    """
    criteria = [
        Order.date >= range_start,
        Order.date < range_end_excl,
        Order.status != "Refunded",
    ]
    if source:
        criteria.append(Order.source == source)
    return Order.query.filter(*criteria).order_by(Order.date.desc()).all()


def _collect_daily_buckets(days):
    """Aggregate non-refunded orders and expenses per calendar day.

    *days* is an ascending list of consecutive ``date`` objects. Returns a
    dict mapping each of those days to a dict with ``web_sales``,
    ``pos_sales``, ``web_orders``, ``pos_orders`` and ``expenses``. Only two
    queries are issued regardless of the number of days.
    """
    buckets = {
        day: {
            "web_sales": 0.0,
            "pos_sales": 0.0,
            "web_orders": 0,
            "pos_orders": 0,
            "expenses": 0.0,
        }
        for day in days
    }
    if not days:
        return buckets

    range_start = datetime.combine(days[0], datetime.min.time())
    range_end_excl = datetime.combine(days[-1] + timedelta(days=1), datetime.min.time())

    order_rows = (
        db.session.query(Order.date, Order.source, Order.final_total)
        .filter(
            Order.date >= range_start,
            Order.date < range_end_excl,
            Order.status != "Refunded",
        )
        .all()
    )
    for order_date, source, final_total in order_rows:
        # Bucket by calendar day; tolerate a plain ``date`` value as well.
        day = order_date.date() if isinstance(order_date, datetime) else order_date
        bucket = buckets.get(day)
        if bucket is None:
            continue
        # Anything not explicitly "Web" (including a NULL source) counts as
        # POS, the same channel split _build_sales_by_date uses.
        channel = "web" if source == "Web" else "pos"
        bucket[f"{channel}_sales"] += float(final_total or 0)
        bucket[f"{channel}_orders"] += 1

    # NOTE(review): assumes Expense.date holds plain dates (it is compared
    # against ``date`` values elsewhere in this module) — confirm.
    expense_rows = (
        db.session.query(Expense.date, Expense.amount)
        .filter(Expense.date >= days[0], Expense.date <= days[-1])
        .all()
    )
    for expense_date, amount in expense_rows:
        bucket = buckets.get(expense_date)
        if bucket is not None:
            bucket["expenses"] += float(amount or 0)

    return buckets


def _build_daily_totals(date_from, date_to):
    """Build the "Daily totals (Web vs POS)" report for an inclusive range.

    Each calendar day from *date_from* to *date_to* (inclusive) gets one row
    with Web sales, POS sales, their total, Web/POS order counts and the
    day's expenses. Refunded orders are excluded. A "Total" footer row is
    appended when the range contains at least one day.

    Args:
        date_from: First day of the range (``date``).
        date_to: Last day of the range (``date``), inclusive.

    Returns:
        The common report payload (``columns``, ``rows``, ``summary``,
        ``show_chart``) plus ``report_daily_rows`` and the chart series
        ``report_chart_labels`` / ``report_chart_web`` / ``report_chart_pos``.
    """
    days = []
    day = date_from
    while day <= date_to:
        days.append(day)
        day += timedelta(days=1)
    # Two range queries instead of five queries per day.
    buckets = _collect_daily_buckets(days)

    report_daily_rows = []
    chart_labels = []
    chart_web = []
    chart_pos = []
    summary_web = summary_pos = summary_total = 0.0
    summary_web_orders = summary_pos_orders = 0
    summary_expenses = 0.0

    for day in days:
        bucket = buckets[day]
        web_sales = bucket["web_sales"]
        # POS is summed directly rather than derived as total - web. The
        # subtraction could leave float residue, and its clamp to 0 hid
        # negative POS totals.
        pos_sales = bucket["pos_sales"]
        total_sales = web_sales + pos_sales
        web_orders = bucket["web_orders"]
        pos_orders = bucket["pos_orders"]
        expense_day = bucket["expenses"]

        report_daily_rows.append(
            {
                "date_label": day.strftime("%a, %d %b %Y"),
                "web_sales": web_sales,
                "pos_sales": pos_sales,
                "total_sales": total_sales,
                "web_orders": web_orders,
                "pos_orders": pos_orders,
                "expenses": expense_day,
            }
        )
        chart_labels.append(day.strftime("%d %b"))
        chart_web.append(web_sales)
        chart_pos.append(pos_sales)
        summary_web += web_sales
        summary_pos += pos_sales
        summary_total += total_sales
        summary_web_orders += web_orders
        summary_pos_orders += pos_orders
        summary_expenses += expense_day

    columns = ["Date", "Web sales", "POS sales", "Total", "Orders W / P", "Expenses"]
    rows = [
        {
            "cells": [
                row["date_label"],
                row["web_sales"],
                row["pos_sales"],
                row["total_sales"],
                f"{row['web_orders']} / {row['pos_orders']}",
                row["expenses"],
            ],
            "numeric": [False, True, True, True, False, True],
        }
        for row in report_daily_rows
    ]
    if rows:
        rows.append(
            {
                "cells": [
                    "Total",
                    summary_web,
                    summary_pos,
                    summary_total,
                    f"{summary_web_orders} / {summary_pos_orders}",
                    summary_expenses,
                ],
                "numeric": [False, True, True, True, False, True],
                "footer": True,
            }
        )

    return {
        "columns": columns,
        "rows": rows,
        "summary": {
            "total_amount": summary_total,
            "total_discount": 0.0,
            "total_after_discount": summary_total,
            "total_profit": 0.0,
            "total_records": summary_web_orders + summary_pos_orders,
        },
        "report_daily_rows": report_daily_rows,
        "report_chart_labels": chart_labels,
        "report_chart_web": chart_web,
        "report_chart_pos": chart_pos,
        "show_chart": True,
    }


def _build_sales_by_date(range_start, range_end_excl):
    """Build the per-order sales report for ``[range_start, range_end_excl)``.

    Produces one row per non-refunded order, newest first, as returned by
    ``_orders_in_range``. The channel is "Web" when ``order.source`` is "Web"
    and "POS" otherwise. The summary totals ``final_total`` and ``discount``
    across those orders.
    """
    gross = 0.0
    discounts = 0.0
    table = []
    for order in _orders_in_range(range_start, range_end_excl):
        final_total = float(order.final_total or 0)
        discount = float(order.discount or 0)
        gross += final_total
        discounts += discount
        if order.date:
            placed_at = order.date.strftime("%Y-%m-%d %H:%M")
        else:
            placed_at = "—"
        table.append(
            {
                "cells": [
                    "POS" if (order.source or "") != "Web" else "Web",
                    order.id,
                    placed_at,
                    order.customer_name or "Walk-in",
                    final_total,
                    discount,
                    order.payment_method or "—",
                    order.status or "—",
                ],
                "numeric": [False, False, False, False, True, True, False, False],
            }
        )

    return {
        "columns": ["Channel", "Order ID", "Date", "Customer", "Total", "Discount", "Payment", "Status"],
        "rows": table,
        "summary": {
            "total_amount": gross,
            "total_discount": discounts,
            "total_after_discount": max(gross, 0.0),
            "total_profit": 0.0,
            "total_records": len(table),
        },
        "show_chart": False,
    }


def _build_sales_by_product(range_start, range_end_excl):
    """Build the per-product sales report for ``[range_start, range_end_excl)``.

    Sums quantity and line revenue (``price * quantity``) per
    ``OrderItem.product_name`` over non-refunded orders in the range, ordered
    by quantity descending. Note that the summary's ``total_records`` is the
    total quantity sold, not the number of product rows.
    """
    aggregates = (
        db.session.query(
            OrderItem.product_name,
            func.sum(OrderItem.quantity).label("total_qty"),
            func.sum(OrderItem.price * OrderItem.quantity).label("total_amount"),
        )
        .join(Order, OrderItem.order_id == Order.id)
        .filter(
            Order.status != "Refunded",
            Order.date >= range_start,
            Order.date < range_end_excl,
        )
        .group_by(OrderItem.product_name)
        .order_by(func.sum(OrderItem.quantity).desc())
        .all()
    )

    table = []
    revenue_sum = 0.0
    units_sum = 0
    for product_name, raw_qty, raw_amount in aggregates:
        units = int(raw_qty or 0)
        revenue = float(raw_amount or 0)
        units_sum += units
        revenue_sum += revenue
        table.append({"cells": [product_name, units, revenue], "numeric": [False, True, True]})

    return {
        "columns": ["Product", "Quantity", "Amount"],
        "rows": table,
        "summary": {
            "total_amount": revenue_sum,
            "total_discount": 0.0,
            "total_after_discount": revenue_sum,
            "total_profit": 0.0,
            "total_records": units_sum,
        },
        "show_chart": False,
    }


def _build_expenses(date_from, date_to):
    """Build the expense report for the inclusive range ``[date_from, date_to]``.

    Lists every matching expense, newest date first with ties broken by
    descending id. The summary totals ``Expense.amount``.
    """
    matching = Expense.query.filter(Expense.date >= date_from, Expense.date <= date_to)
    ordered = matching.order_by(Expense.date.desc(), Expense.id.desc()).all()

    table = []
    grand_total = 0.0
    for item in ordered:
        value = float(item.amount or 0)
        grand_total += value
        when = item.date.strftime("%Y-%m-%d") if item.date else "—"
        cells = [when, item.title, item.category or "—", value, item.method or "—"]
        table.append({"cells": cells, "numeric": [False, False, False, True, False]})

    return {
        "columns": ["Date", "Title", "Category", "Amount", "Method"],
        "rows": table,
        "summary": {
            "total_amount": grand_total,
            "total_discount": 0.0,
            "total_after_discount": grand_total,
            "total_profit": 0.0,
            "total_records": len(table),
        },
        "show_chart": False,
    }


def _build_pos_sync_sales(range_start, range_end_excl):
    """Build the POS-order report for ``[range_start, range_end_excl)``.

    Only non-refunded orders whose ``source`` is exactly "POS" are listed,
    newest first. The summary totals ``final_total``.
    """
    table = []
    grand_total = 0.0
    for order in _orders_in_range(range_start, range_end_excl, source="POS"):
        value = float(order.final_total or 0)
        grand_total += value
        when = order.date.strftime("%Y-%m-%d %H:%M") if order.date else "—"
        cells = [
            order.id,
            when,
            order.customer_name or "Walk-in",
            order.payment_method or "—",
            value,
            order.status or "—",
        ]
        table.append({"cells": cells, "numeric": [False, False, False, False, True, False]})

    return {
        "columns": ["Order ID", "Date", "Customer", "Payment", "Amount", "Status"],
        "rows": table,
        "summary": {
            "total_amount": grand_total,
            "total_discount": 0.0,
            "total_after_discount": grand_total,
            "total_profit": 0.0,
            "total_records": len(table),
        },
        "show_chart": False,
    }


def _resolve_report_dates(args, today):
    """Return the ``(date_from, date_to)`` window requested by *args*.

    Defaults to ``today - 30 days`` .. ``today``. Reversed bounds are swapped,
    and the span is capped so ``date_from`` is at most 90 days before
    ``date_to``.
    """
    date_to = parse_report_date(args.get("to"), today)
    date_from = parse_report_date(args.get("from"), today - timedelta(days=30))
    if date_from > date_to:
        date_from, date_to = date_to, date_from
    if (date_to - date_from).days > 90:
        date_from = date_to - timedelta(days=90)
    return date_from, date_to


def build_reports_page_context(req):
    """Build the template context for the admin reports page.

    Query parameters read from ``req.args``:
      * ``from`` / ``to``: ``YYYY-MM-DD`` bounds, defaulting to the last 30
        days up to today (UTC).
      * ``type``: a key of ``ADMIN_REPORT_TYPES``; unknown values fall back to
        ``"sales_by_date"``.
      * ``generate``: report data is only computed when this equals ``"1"``.

    Args:
        req: Object exposing a mapping-like ``args`` with ``get`` (presumably
            a Flask ``request``; confirm against callers).

    Returns:
        dict of template variables. The report data keys are empty unless a
        report was generated. The 10 most recent ``PosSyncLog`` rows are
        always included.
    """
    # Compute "today" once so the default range and the quick links cannot
    # disagree if the request straddles UTC midnight. datetime.utcnow() is
    # deprecated since Python 3.12; this yields the same UTC calendar date.
    today = datetime.now(timezone.utc).date()
    date_from, date_to = _resolve_report_dates(req.args, today)

    report_labels = dict(ADMIN_REPORT_TYPES)
    report_type = (req.args.get("type") or "sales_by_date").strip()
    if report_type not in report_labels:
        report_type = "sales_by_date"
    report_generated = req.args.get("generate") == "1"

    report_quick_links = [
        {
            "label": label,
            # A span of N days includes today, hence the N - 1 offset.
            "from": (today - timedelta(days=span_days - 1)).isoformat(),
            "to": today.isoformat(),
        }
        for label, span_days in (("7 days", 7), ("30 days", 30), ("90 days", 90))
    ]

    context = {
        "report_types": ADMIN_REPORT_TYPES,
        "report_type": report_type,
        "report_label": report_labels[report_type],
        "report_generated": report_generated,
        "report_date_from": date_from,
        "report_date_to": date_to,
        "report_date_from_iso": date_from.isoformat(),
        "report_date_to_iso": date_to.isoformat(),
        "report_quick_links": report_quick_links,
        "report_columns": [],
        "report_rows": [],
        "report_summary": {
            "total_amount": 0.0,
            "total_discount": 0.0,
            "total_after_discount": 0.0,
            "total_profit": 0.0,
            "total_records": 0,
        },
        "report_show_chart": False,
        "report_chart_labels": [],
        "report_chart_web": [],
        "report_chart_pos": [],
        "pos_sync_logs": PosSyncLog.query.order_by(PosSyncLog.synced_at.desc()).limit(10).all(),
    }

    if not report_generated:
        return context

    # Half-open datetime window covering whole days date_from..date_to.
    range_start = datetime.combine(date_from, datetime.min.time())
    range_end_excl = datetime.combine(date_to + timedelta(days=1), datetime.min.time())

    # report_type was validated above, so every value has a builder.
    builders = {
        "sales_by_date": lambda: _build_sales_by_date(range_start, range_end_excl),
        "sales_by_product": lambda: _build_sales_by_product(range_start, range_end_excl),
        "expenses": lambda: _build_expenses(date_from, date_to),
        "daily_totals": lambda: _build_daily_totals(date_from, date_to),
        "pos_sync_sales": lambda: _build_pos_sync_sales(range_start, range_end_excl),
    }
    payload = builders[report_type]()

    context.update(
        {
            "report_columns": payload["columns"],
            "report_rows": payload["rows"],
            "report_summary": payload["summary"],
            "report_show_chart": payload.get("show_chart", False),
            "report_chart_labels": payload.get("report_chart_labels", []),
            "report_chart_web": payload.get("report_chart_web", []),
            "report_chart_pos": payload.get("report_chart_pos", []),
        }
    )
    return context
