import logging
import math
from datetime import datetime, timedelta

from flask import current_app

from app_server.extensions import db
from app_server.models import VerificationCode
from app_server.utils.security import generate_otp, hash_otp, normalize_email, normalize_phone, verify_otp

# Module-level logger; send_phone_otp() uses it to warn when SMS delivery fails
# and the OTP is surfaced via a dev-log / fallback path instead.
logger = logging.getLogger(__name__)


def _otp_settings():
    """Read the OTP tuning knobs from the Flask app config, coerced to ``int``.

    Returned keys (config name, default when unset):
      - ``length``          (OTP_LENGTH, 6)
      - ``expiry_minutes``  (OTP_EXPIRY_MINUTES, 10)
      - ``max_attempts``    (OTP_MAX_ATTEMPTS, 5)
      - ``resend_cooldown`` (OTP_RESEND_COOLDOWN_SECONDS, 60)
    """
    config = current_app.config
    spec = (
        ("length", "OTP_LENGTH", 6),
        ("expiry_minutes", "OTP_EXPIRY_MINUTES", 10),
        ("max_attempts", "OTP_MAX_ATTEMPTS", 5),
        ("resend_cooldown", "OTP_RESEND_COOLDOWN_SECONDS", 60),
    )
    return {name: int(config.get(key, default)) for name, key, default in spec}


def _destination(channel: str, destination: str) -> str:
    if channel == "email":
        return normalize_email(destination)
    if channel == "phone":
        return normalize_phone(destination)
    raise ValueError("Unsupported OTP channel")


def _recent_code(channel: str, destination: str, purpose: str):
    """Return the newest VerificationCode for (channel, destination, purpose).

    *destination* is normalized via ``_destination`` before querying, so callers
    may pass a raw or already-normalized value (assumes normalization is
    idempotent — TODO confirm). "Newest" means greatest ``created_at``.
    Returns whatever ``Query.first()`` yields when nothing matches (None).

    Raises:
        ValueError: if ``channel`` is unsupported (from ``_destination``).
    """
    normalized = _destination(channel, destination)
    matching = VerificationCode.query.filter_by(
        channel=channel,
        destination=normalized,
        purpose=purpose,
    )
    newest_first = matching.order_by(VerificationCode.created_at.desc())
    return newest_first.first()


def can_resend_otp(channel: str, destination: str, purpose: str) -> tuple[bool, int]:
    """Decide whether a new OTP may be issued yet for this contact and purpose.

    The cooldown (OTP_RESEND_COOLDOWN_SECONDS) is measured from ``created_at``
    of the most recent VerificationCode for the same channel, normalized
    destination and purpose.

    Returns:
        ``(True, 0)`` when there is no prior code, the prior code has no
        ``created_at``, or the cooldown has fully elapsed. Otherwise
        ``(False, seconds_to_wait)``, with the wait rounded *up* to whole
        seconds.

    Raises:
        ValueError: if ``channel`` is unsupported (from ``_destination``).
    """
    settings = _otp_settings()
    latest = _recent_code(channel, destination, purpose)
    if not latest or not latest.created_at:
        return True, 0
    # NOTE(review): assumes created_at is a naive UTC datetime, like utcnow().
    elapsed = (datetime.utcnow() - latest.created_at).total_seconds()
    # Round up rather than truncate. With int(), a sub-second remainder such
    # as 0.4s became 0 and the resend was allowed before the cooldown ended.
    # It also under-reported the wait shown to the user by up to one second.
    remaining = math.ceil(settings["resend_cooldown"] - elapsed)
    if remaining > 0:
        return False, remaining
    return True, 0


def issue_otp(channel: str, destination: str, purpose: str) -> tuple[str, VerificationCode]:
    """Generate, persist and return a fresh OTP for a contact and purpose.

    Only the hash of the code is stored on the new VerificationCode row. The
    plaintext is returned so the caller can deliver it. The row gets
    ``attempt_count=0`` and an ``expires_at`` OTP_EXPIRY_MINUTES from now
    (naive UTC). The row is added to the session and committed.

    Returns:
        ``(plaintext_code, record)``.

    Raises:
        ValueError: if the resend cooldown has not elapsed (message tells the
            user how many seconds to wait), or if ``channel`` is unsupported.
    """
    cfg = _otp_settings()
    normalized = _destination(channel, destination)

    cooldown_ok, wait_for = can_resend_otp(channel, normalized, purpose)
    if not cooldown_ok:
        raise ValueError(f"Please wait {wait_for} seconds before requesting another code.")

    plain_code = generate_otp(cfg["length"])
    hashed = hash_otp(plain_code)
    lifetime = timedelta(minutes=cfg["expiry_minutes"])
    new_record = VerificationCode(
        channel=channel,
        destination=normalized,
        purpose=purpose,
        code_hash=hashed,
        expires_at=datetime.utcnow() + lifetime,
        attempt_count=0,
        verified_at=None,
    )
    db.session.add(new_record)
    db.session.commit()
    return plain_code, new_record


def verify_issued_otp(channel: str, destination: str, purpose: str, code: str) -> tuple[bool, str]:
    """Check *code* against the newest OTP issued for (channel, destination, purpose).

    Only the most recently created VerificationCode row is considered. The
    following rejections return without committing anything: no code, already
    used, expired, or attempt limit reached. Otherwise ``attempt_count`` is
    incremented and committed whether or not the code matches. On a match,
    ``verified_at`` is also stamped with the current naive-UTC time.

    Returns:
        ``(True, "Verified")`` on success, else ``(False, <user-facing reason>)``.

    Raises:
        ValueError: if ``channel`` is unsupported (from ``_destination``).
    """
    limits = _otp_settings()
    normalized = _destination(channel, destination)
    latest = _recent_code(channel, normalized, purpose)

    # Guard clauses: the newest code cannot be checked at all.
    if not latest:
        return False, "No verification code found. Request a new one."
    if latest.verified_at:
        return False, "This code was already used. Request a new one."
    expires_at = latest.expires_at
    if expires_at and expires_at < datetime.utcnow():
        return False, "Verification code expired. Request a new one."
    max_tries = limits["max_attempts"]
    if latest.attempt_count >= max_tries:
        return False, "Too many failed attempts. Request a new code."

    # NOTE(review): this read-modify-write of attempt_count is not atomic.
    # Concurrent requests might each pass the limit check above. Consider a
    # row lock or an atomic UPDATE if brute-force resistance matters here.
    latest.attempt_count += 1
    if verify_otp(code, latest.code_hash):
        latest.verified_at = datetime.utcnow()
        db.session.commit()
        return True, "Verified"

    # Failed guess: persist the consumed attempt before reporting what is left.
    db.session.commit()
    attempts_left = max(max_tries - latest.attempt_count, 0)
    if attempts_left == 0:
        return False, "Too many failed attempts. Request a new code."
    return False, f"Invalid verification code. {attempts_left} attempt(s) left."


def is_contact_verified(channel: str, destination: str, purpose: str = "registration") -> bool:
    """Report whether *destination* was verified for *purpose* recently enough.

    Finds the VerificationCode with the latest non-null ``verified_at`` for the
    normalized destination. Returns True only if that timestamp is no older than
    REGISTRATION_VERIFICATION_GRACE_MINUTES (default 30), measured against
    naive-UTC now.

    Raises:
        ValueError: if ``channel`` is unsupported (from ``_destination``).
    """
    normalized = _destination(channel, destination)
    matching = VerificationCode.query.filter_by(
        channel=channel,
        destination=normalized,
        purpose=purpose,
    )
    verified_only = matching.filter(VerificationCode.verified_at.isnot(None))
    latest_verified = verified_only.order_by(VerificationCode.verified_at.desc()).first()

    verified_at = latest_verified.verified_at if latest_verified else None
    if not verified_at:
        return False

    grace = int(current_app.config.get("REGISTRATION_VERIFICATION_GRACE_MINUTES", 30))
    age = datetime.utcnow() - verified_at
    return age <= timedelta(minutes=grace)


def send_email_otp(email: str, code: str) -> bool:
    """Email the verification *code* to *email*.

    Returns whatever ``send_email_message`` reports (annotated here as bool).
    """
    # Imported lazily; presumably to avoid an import cycle — TODO confirm.
    from app_server.utils.notifications import send_email_message

    # Use the same int-coerced expiry that issue_otp() applies to expires_at.
    # The advertised lifetime then always matches the real one. Previously the
    # raw config value was re-read here, e.g. "10.0" was shown while 10 was
    # enforced.
    expiry_minutes = _otp_settings()["expiry_minutes"]
    subject = "SwiftCart verification code"
    body = (
        f"Your SwiftCart email verification code is {code}.\n\n"
        f"It expires in {expiry_minutes} minutes.\n"
        "If you did not request this, you can ignore this email."
    )
    return send_email_message(email, subject, body)


def send_phone_otp(phone: str, code: str) -> tuple[bool, str]:
    """
    Send the verification *code* to *phone* by SMS, with optional fallbacks.

    Returns ``(True, mode)`` where *mode* is one of:
      - ``"sms"``: ``send_sms_notification`` reported success.
      - ``"dev_log"``: SMS failed and REGISTRATION_OTP_DEV_LOG is enabled.
        The OTP is written to the logger and printed to stdout.
      - ``"fallback"``: SMS failed and SMS_OTP_FALLBACK_ON_FAILURE is enabled
        (treated as True when the key is missing). A warning is logged and
        the OTP is printed to stdout.
    Returns ``(False, "")`` when SMS failed and neither fallback is enabled.

    NOTE(review): because the fallback defaults to on, a missing config key
    means failed SMS deliveries still report success and leak the plaintext
    OTP to stdout. Confirm this is disabled in production.
    """
    # Imported lazily; presumably to avoid an import cycle with the routes
    # module — TODO confirm.
    from app_server.api.routes import send_sms_notification

    # Advertise the same int-coerced expiry that issue_otp() enforces.
    expiry_minutes = _otp_settings()["expiry_minutes"]
    message = (
        f"SwiftCart verification code: {code}. "
        f"Expires in {expiry_minutes} minutes."
    )
    if send_sms_notification(phone, message):
        return True, "sms"

    if current_app.config.get("REGISTRATION_OTP_DEV_LOG", False):
        logger.warning("SMS not sent; REGISTRATION_OTP_DEV_LOG enabled. OTP for %s: %s", phone, code)
        print(f"[SwiftCart OTP] SMS to {phone}: {message}")
        return True, "dev_log"

    if current_app.config.get("SMS_OTP_FALLBACK_ON_FAILURE", True):
        logger.warning(
            "Twilio SMS failed for %s (often: trial account → verify number at twilio.com, or upgrade). "
            "OTP is printed below. Set SWIFTCART_SMS_OTP_FALLBACK=0 to disable this fallback.",
            phone,
        )
        print(f"[SwiftCart OTP] SMS delivery failed — use this code for {phone}: {code}")
        print(message)
        return True, "fallback"

    return False, ""
