"""
ai_rest_api.py — Complete AI REST API for Flow, HTTP, DNS, TLS, and Meta analysis

Author: Sergey Filipovich
Email: sergey.filipovich@suri-oculus.com

All rights reserved © Sergey Filipovich.
This code is proprietary software and must not be copied, modified, or distributed 
without explicit written permission from the author.
"""
from fastapi import FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict, Any
from fastapi.responses import JSONResponse
from integrated_model import extract_meta_features, load_model
import joblib
import pandas as pd
import os
import redis
import json
import numpy as np
from datetime import datetime, timedelta

#from http_ai_module import extract_http_features, analyze_and_tag as analyze_http
from dns_ai_module import analyze_and_tag as analyze_dns
from tls_ai_module import analyze_and_tag as analyze_tls

from flow_ai_module import extract_features as extract_flow_features, analyze_and_tag as analyze_flow
from http_ai_module import analyze_and_tag as analyze_http
# Application instance; every route in this module is registered on it.
app = FastAPI()

# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): a wildcard origin together with allow_credentials=True is
# normally discouraged — confirm this is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model file (loaded with joblib) for each supported event type. The keys are
# the type names accepted by the routes and are also used as Redis list keys.
MODEL_PATHS = dict(
    http="http_model.pkl",
    dns="dns_model.pkl",
    flow="flow_model.pkl",
    tls="tls_model.pkl",
    meta="meta_model.pkl",
)
# Text file listing the meta-model feature columns, one name per line.
META_FEATURES_FILE = "meta_model_features.txt"
from c_features import extract_features

# Feature extractor per event type. Each entry forwards (events, cols) to the
# shared extract_features helper with the event type fixed as first argument.
EXTRACTORS = {
    "flow": lambda events, cols=None: extract_features("flow", events, cols),
    "dns": lambda events, cols=None: extract_features("dns", events, cols),
    "http": lambda events, cols=None: extract_features("http", events, cols),
    "tls": lambda events, cols=None: extract_features("tls", events, cols)
}
def analyze_meta(_, model):
    """Score the current meta-feature snapshot with the meta model.

    Unlike the other analyzers, the input is not taken from the caller: the
    snapshot is computed from the local Redis instance.

    Args:
        _: Ignored; present so the function can be called as
            ``analyzer(events, model)`` like the other ``ANALYZERS`` entries.
        model: Fitted estimator exposing ``predict``. A prediction of ``-1``
            for the snapshot is treated as an anomaly.

    Returns:
        A pair ``(anomalies, normals)``. Exactly one of the lists holds a
        single record (the snapshot's columns plus a ``"meta"`` key set to
        ``"anomaly"`` or ``"normal"``); both lists are empty when no snapshot
        is available.
    """
    r = redis.Redis(host="localhost", port=6379, decode_responses=True)

    # 1) Compute the meta-feature snapshot (first row is the one scored).
    df = extract_meta_features(r, minutes=60, shift=0)
    if df.empty:
        return [], []

    # 2) Decide which column layout to enforce: the names the model was
    #    fitted with when it exposes them, otherwise the fallback schema file.
    expected = getattr(model, "feature_names_in_", None)
    if expected is not None:
        columns = list(expected)
    elif os.path.exists(META_FEATURES_FILE):
        with open(META_FEATURES_FILE) as f:
            columns = [line.strip() for line in f if line.strip()]
    else:
        columns = []

    # 3) Align to that layout: add missing columns as 0, drop extras, and
    #    put the columns in the expected order.
    if columns:
        for col in columns:
            if col not in df.columns:
                df[col] = 0
        df = df.loc[:, columns]

    # 4) Guard against NaN and non-numeric values before predicting.
    df = df.fillna(0)
    for c in df.columns:
        if df[c].dtype == object:
            df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0)

    # 5) Predict and wrap the snapshot in the (anomalies, normals) shape.
    prediction = model.predict(df)[0]
    record = df.to_dict(orient="records")[0]
    if prediction == -1:
        return [{"meta": "anomaly", **record}], []
    return [], [{"meta": "normal", **record}]


# Analyzer callable per event type. The routes call each one as
# analyzer(events, model) and unpack the result as (anomalies, normals).
ANALYZERS = dict(
    http=analyze_http,
    dns=analyze_dns,
    flow=analyze_flow,
    tls=analyze_tls,
    meta=analyze_meta,  # ignores the events argument and reads Redis itself
)
def filter_by_time(events, start, end):
    """Decode JSON events and keep those with ``start <= timestamp < end``.

    Args:
        events: Iterable of JSON strings, each decoding to an object.
        start: Inclusive lower bound (naive ``datetime``).
        end: Exclusive upper bound (naive ``datetime``).

    Returns:
        List of decoded events inside the window, in input order.

    Raises:
        ValueError: If an entry is not valid JSON or its timestamp is not in
            ISO format.
    """
    selected = []
    for raw in events:
        # Decode once and reuse the result for both the check and the output.
        event = json.loads(raw)
        # Events without a timestamp get a far-future date so they fall
        # outside any realistic window.
        stamp = datetime.fromisoformat(event.get("timestamp", "2100-01-01"))
        # Any UTC offset is dropped so the value compares with naive bounds.
        if start <= stamp.replace(tzinfo=None) < end:
            selected.append(event)
    return selected

def group_by_hour(events):
    """Count events per hour of day, based on their ``timestamp`` field.

    Events with a missing or empty timestamp are ignored. A timestamp that
    cannot be parsed is reported on stdout and skipped.

    Args:
        events: Iterable of dict-like events.

    Returns:
        Dict mapping hour (``int``) to the number of events in that hour.
    """
    counts = {}
    for event in events:
        stamp = event.get("timestamp")
        if not stamp:
            continue
        try:
            bucket = datetime.fromisoformat(stamp).hour
        except Exception as ex:
            print("⚠️ Error parsing timestamp:", stamp, ex)
        else:
            counts[bucket] = counts.get(bucket, 0) + 1
    return counts

# Request body for POST /ai/analyze/{event_type}.
class EventList(BaseModel):
    # Events to analyze; each one is an arbitrary JSON object.
    events: List[Dict[str, Any]]

@app.get("/")
def root():
    """Report that the API process is up."""
    status = {"message": "Suri-Oculus AI REST API is running"}
    return status

@app.post("/ai/analyze/{event_type}")
def analyze_event_type(event_type: str, request: EventList):
    """Analyze the events supplied in the request body.

    Args:
        event_type: Key into ``MODEL_PATHS`` (matched case-insensitively).
        request: Body carrying the events to analyze.

    Returns:
        On success, a dict with the number of submitted events, the number of
        anomalies, and at most the first ten anomalies. On any failure, a
        dict with a single ``"error"`` message.
    """
    kind = event_type.lower()
    if kind not in MODEL_PATHS:
        return {"error": f"Unsupported event type: {kind}"}

    path = MODEL_PATHS[kind]
    run_analysis = ANALYZERS[kind]
    if not os.path.exists(path):
        return {"error": f"Model for '{kind}' not found"}

    estimator = joblib.load(path)
    submitted = request.events
    try:
        # Only the anomalous events are reported; normal ones are discarded.
        flagged, _normal = run_analysis(submitted, estimator)
        summary = {
            "total": len(submitted),
            "anomaly_count": len(flagged),
            "anomalies": flagged[:10],
        }
        return summary
    except Exception as e:
        return {"error": f"Analysis failed: {str(e)}"}

# Universal analyzer from Redis
@app.get("/api/analyze/{event_type}")
def analyze_from_redis(event_type: str, limit: int = 100, offset: int = 0, only_anomalies: bool = False):
    """Analyze one page of events read from the Redis list ``event_type``.

    Args:
        event_type: Key into ``MODEL_PATHS``; also the Redis list key.
        limit: Maximum number of list entries to read. A value of zero or
            less yields an empty page.
        offset: Index of the first list entry to read.
        only_anomalies: When true, only events whose ``"anomaly"`` field is
            truthy are included in the response.

    Returns:
        Dict with ``total`` (length of the Redis list), ``anomaly_count`` and
        ``events`` (index, source/destination IP, anomaly flag and the full
        cleaned event). Responds with HTTP 400 for an unknown event type and
        HTTP 500 when Redis access or the analysis fails.
    """
    if event_type not in MODEL_PATHS:
        return JSONResponse(status_code=400, content={"error": f"Unsupported event type: {event_type}"})

    redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True)
    try:
        total_available = redis_client.llen(event_type)

        # LRANGE's stop index is inclusive and negative values count from the
        # tail, so a non-positive limit would turn `offset + limit - 1` into
        # "up to the end of the list". Treat it as an empty page instead.
        if limit <= 0:
            raw_events = []
        else:
            raw_events = redis_client.lrange(event_type, offset, offset + limit - 1)

        parsed_events = []
        for i, item in enumerate(raw_events):
            try:
                event = json.loads(item)
            except json.JSONDecodeError:
                continue  # skip entries that are not valid JSON
            # Remember the position in the Redis list for the response.
            event["redis_index"] = offset + i
            parsed_events.append(event)

        if not parsed_events:
            return {"total": total_available, "anomaly_count": 0, "events": []}

        model = joblib.load(MODEL_PATHS[event_type])
        anomalies, normals = ANALYZERS[event_type](parsed_events, model)

        results = []
        for ev in anomalies + normals:
            # NOTE(review): relies on the analyzer marking events with an
            # "anomaly" field — confirm against the analyzer modules.
            if only_anomalies and not ev.get("anomaly"):
                continue
            results.append({
                "index": ev.get("redis_index"),
                "src_ip": ev.get("src_ip"),
                "dst_ip": ev.get("dest_ip"),
                "anomaly": bool(ev.get("anomaly", False)),
                "full": clean_event(ev)
            })

        return {
            "total": total_available,
            "anomaly_count": len(anomalies),
            "events": results
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        return JSONResponse(status_code=500, content={"error": f"Redis or analysis failed: {str(e)}"})

# Universal timeline
@app.get("/api/analyze_{event_type}_timeline")
def timeline(event_type: str, limit: int = 100, offset: int = 0, interval: str = Query("minute", regex="^(minute|hour|day)$")):
    """Bucket the anomalies found in one page of Redis events by time.

    Args:
        event_type: Key into ``MODEL_PATHS``; also the Redis list key.
        limit: Maximum number of list entries to read. A value of zero or
            less yields an empty timeline.
        offset: Index of the first list entry to read.
        interval: Bucket size: ``minute``, ``hour`` or ``day``.

    Returns:
        Dict with ``interval`` and ``timeline``, a time-ordered list of
        ``{"time", "anomalies"}`` entries; or a dict with an ``"error"``
        message on failure.
    """
    if event_type not in MODEL_PATHS:
        return {"error": f"Unsupported event type: {event_type}"}
    if limit <= 0:
        # LRANGE's stop index is inclusive and negative values count from the
        # tail, so `offset + limit - 1` would otherwise select far more than
        # the requested (empty) page.
        return {"interval": interval, "timeline": []}

    redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True)
    try:
        raw_events = redis_client.lrange(event_type, offset, offset + limit - 1)
        parsed = [json.loads(x) for x in raw_events if x.strip()]
        model = joblib.load(MODEL_PATHS[event_type])
        anomalies, _ = ANALYZERS[event_type](parsed, model)

        timestamps = [pd.to_datetime(e["timestamp"]) for e in anomalies if e.get("timestamp")]
        if not timestamps:
            return {"interval": interval, "timeline": []}

        # "min" / "h" / "D" are the current pandas spellings; the one-letter
        # "T" / "H" aliases are deprecated in recent pandas releases.
        freq = {"minute": "min", "hour": "h", "day": "D"}[interval]
        df = pd.DataFrame({"ts": timestamps})
        counts = df.groupby(df["ts"].dt.floor(freq)).size()

        return {
            "interval": interval,
            "timeline": [
                {"time": bucket.strftime("%Y-%m-%d %H:%M"), "anomalies": int(count)}
                for bucket, count in counts.items()
            ]
        }
    except Exception as e:
        return {"error": f"Timeline failed: {str(e)}"}

# Universal anomaly comparison (hour-based)
@app.get("/api/anomaly_{event_type}_compare")
def anomaly_compare(event_type: str, from_hour: int = 12, to_hour: int = 14):
    """Compare anomaly counts for the same hour window today and yesterday.

    Args:
        event_type: Key into ``MODEL_PATHS``; also the Redis list key.
        from_hour: First hour of the window (inclusive), 0-23.
        to_hour: End hour of the window (exclusive), 0-23.

    Returns:
        Dict with the window label and, for today and yesterday, the number
        of events in the window and how many were flagged as anomalies; or a
        dict with an ``"error"`` message on failure.
    """
    if event_type not in MODEL_PATHS:
        return {"error": f"Unsupported event type: {event_type}"}
    try:
        redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True)
        all_data = redis_client.lrange(event_type, 0, -1)

        # Zero the seconds and microseconds too, so the window starts and
        # ends exactly on the hour instead of inheriting the current time's
        # sub-minute part.
        now = datetime.now()
        window_start = now.replace(hour=from_hour, minute=0, second=0, microsecond=0)
        window_end = now.replace(hour=to_hour, minute=0, second=0, microsecond=0)
        one_day = timedelta(days=1)

        today = filter_by_time(all_data, window_start, window_end)
        yest = filter_by_time(all_data, window_start - one_day, window_end - one_day)

        model = joblib.load(MODEL_PATHS[event_type])
        today_anom, _ = ANALYZERS[event_type](today, model)
        yest_anom, _ = ANALYZERS[event_type](yest, model)

        return {
            "interval": f"{from_hour:02d}:00–{to_hour:02d}:00",
            "today": {"total": len(today), "anomalies": len(today_anom)},
            "yesterday": {"total": len(yest), "anomalies": len(yest_anom)}
        }
    except Exception as e:
        return {"error": f"Anomaly compare failed: {str(e)}"}

@app.get("/api/anomaly_{event_type}_compare_timeline")
def anomaly_compare_timeline(event_type: str, from_hour: int = 12, to_hour: int = 14):
    """Per-hour anomaly counts for the same window today and yesterday.

    Args:
        event_type: Key into ``MODEL_PATHS``; also the Redis list key.
        from_hour: First hour of the window (inclusive), 0-23.
        to_hour: End hour of the window (exclusive), 0-23.

    Returns:
        Dict with the window label and ``timeline``, one entry per hour in
        ``range(from_hour, to_hour)`` holding today's and yesterday's anomaly
        counts; or a dict with an ``"error"`` message on failure.
    """
    if event_type not in MODEL_PATHS:
        return {"error": f"Unsupported event type: {event_type}"}
    try:
        redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True)
        all_data = redis_client.lrange(event_type, 0, -1)

        # Zero the seconds and microseconds too, so the window starts and
        # ends exactly on the hour instead of inheriting the current time's
        # sub-minute part.
        now = datetime.now()
        window_start = now.replace(hour=from_hour, minute=0, second=0, microsecond=0)
        window_end = now.replace(hour=to_hour, minute=0, second=0, microsecond=0)
        one_day = timedelta(days=1)

        today = filter_by_time(all_data, window_start, window_end)
        yest = filter_by_time(all_data, window_start - one_day, window_end - one_day)

        model = joblib.load(MODEL_PATHS[event_type])
        today_anom, _ = ANALYZERS[event_type](today, model)
        yest_anom, _ = ANALYZERS[event_type](yest, model)

        today_by = group_by_hour(today_anom)
        yest_by = group_by_hour(yest_anom)

        result = [
            {
                "hour": f"{h:02d}:00",
                "today": today_by.get(h, 0),
                "yesterday": yest_by.get(h, 0),
            }
            for h in range(from_hour, to_hour)
        ]

        return {
            "interval": f"{from_hour:02d}:00–{to_hour:02d}:00",
            "timeline": result
        }
    except Exception as e:
        return {"error": f"Anomaly timeline compare failed: {str(e)}"}



def clean_event(ev):
    """Recursively convert NumPy scalars inside an event to built-in types.

    Dicts and lists are rebuilt with their contents converted. Booleans,
    integers and floats (NumPy or built-in) are passed through ``bool``,
    ``int`` and ``float`` respectively. Anything else is returned unchanged.

    Args:
        ev: Event, or any value nested inside one.

    Returns:
        The converted value.
    """
    if isinstance(ev, dict):
        cleaned = {}
        for key, value in ev.items():
            cleaned[key] = clean_event(value)
        return cleaned
    if isinstance(ev, list):
        return list(map(clean_event, ev))

    # Order matters: bool is a subclass of int, so it has to be tested first.
    scalar_casts = (
        ((np.bool_, bool), bool),
        ((np.integer, int), int),
        ((np.floating, float), float),
    )
    for kinds, cast in scalar_casts:
        if isinstance(ev, kinds):
            return cast(ev)
    return ev

from fastapi import FastAPI, Query
from typing import Optional

@app.get("/api/anomaly_timeline_all")
def anomaly_timeline_all(
    interval: str = Query("minute", regex="^(minute|hour|day)$"),
    minutes: int = 60,
    from_time: Optional[str] = None,
    min_count: int = 0
):
    """Build one anomaly timeline covering every event type in ``MODEL_PATHS``.

    Args:
        interval: Bucket size: ``minute``, ``hour`` or ``day``.
        minutes: Look-back window, used when ``from_time`` is not given.
        from_time: Optional ISO timestamp for the start of the window; a UTC
            offset, if present, is dropped.
        min_count: Buckets whose total across all types is below this value
            are left out.

    Returns:
        Dict with ``interval`` and ``timeline``, a time-ordered list of
        entries holding ``time`` and one count per event type; or a dict with
        an ``"error"`` message on failure.
    """
    try:
        redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True)

        if from_time:
            # Event timestamps are compared as naive datetimes below, so drop
            # any offset given by the caller; comparing an aware value with a
            # naive one raises TypeError.
            start_time = datetime.fromisoformat(from_time).replace(tzinfo=None)
        else:
            start_time = datetime.now() - timedelta(minutes=minutes)

        # "min" / "h" / "D" are the current pandas spellings; the one-letter
        # "T" / "H" aliases are deprecated in recent pandas releases.
        freq = {"minute": "min", "hour": "h", "day": "D"}[interval]

        timeline_data = {}

        for event_type in MODEL_PATHS:
            raw_events = redis_client.lrange(event_type, 0, -1)
            parsed = [json.loads(x) for x in raw_events if x.strip()]
            recent_events = [
                ev for ev in parsed
                if "timestamp" in ev and datetime.fromisoformat(ev["timestamp"]).replace(tzinfo=None) >= start_time
            ]

            model = joblib.load(MODEL_PATHS[event_type])
            anomalies, _ = ANALYZERS[event_type](recent_events, model)

            timestamps = [
                pd.to_datetime(ev["timestamp"]) for ev in anomalies if "timestamp" in ev
            ]
            if not timestamps:
                continue

            df = pd.DataFrame({"ts": timestamps})
            counts = df.groupby(df["ts"].dt.floor(freq)).size()

            for bucket, count in counts.items():
                key = bucket.strftime("%Y-%m-%d %H:%M")
                # Store plain ints so the response serializes cleanly.
                timeline_data.setdefault(key, {})[event_type] = int(count)

        # Final timeline assembly with filtering by min_count
        timeline = []
        for key in sorted(timeline_data):
            per_type = timeline_data[key]
            entry = {"time": key}
            total = 0
            for etype in MODEL_PATHS:
                count = per_type.get(etype, 0)
                entry[etype] = count
                total += count
            if total >= min_count:
                timeline.append(entry)

        return {
            "interval": interval,
            "timeline": timeline
        }

    except Exception as e:
        return {"error": f"Timeline all failed: {str(e)}"}

from fastapi import Query
from typing import Optional

@app.get("/api/anomaly_timeline_by_ip")
def anomaly_timeline_by_ip(
    ip: str,
    interval: str = Query("minute", regex="^(minute|hour|day)$"),
    minutes: int = 60
):
    """Build an anomaly timeline restricted to one IP address.

    An anomaly is counted when its ``src_ip`` or ``dest_ip`` equals ``ip``.
    Every event type in ``MODEL_PATHS`` contributes to the same counts.

    Args:
        ip: Address to match against source and destination.
        interval: Bucket size: ``minute``, ``hour`` or ``day``.
        minutes: Look-back window, in minutes, counted from now.

    Returns:
        Dict with ``interval``, ``ip`` and ``timeline``, a time-ordered list
        of ``{"time", "count"}`` entries; or a dict with an ``"error"``
        message on failure.
    """
    try:
        redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True)
        start_time = datetime.now() - timedelta(minutes=minutes)

        # "min" / "h" / "D" are the current pandas spellings; the one-letter
        # "T" / "H" aliases are deprecated in recent pandas releases.
        freq = {"minute": "min", "hour": "h", "day": "D"}[interval]

        timeline = {}

        for event_type in MODEL_PATHS:
            raw_events = redis_client.lrange(event_type, 0, -1)
            parsed = [json.loads(x) for x in raw_events if x.strip()]
            recent = [
                e for e in parsed
                if "timestamp" in e and datetime.fromisoformat(e["timestamp"]).replace(tzinfo=None) >= start_time
            ]

            model = joblib.load(MODEL_PATHS[event_type])
            anomalies, _ = ANALYZERS[event_type](recent, model)

            for ev in anomalies:
                ts = ev.get("timestamp")
                if not ts:
                    continue

                if ev.get("src_ip") != ip and ev.get("dest_ip") != ip:
                    continue

                label = pd.to_datetime(ts).floor(freq).strftime("%Y-%m-%d %H:%M")
                timeline[label] = timeline.get(label, 0) + 1

        timeline_sorted = [
            {"time": k, "count": timeline[k]}
            for k in sorted(timeline)
        ]

        return {
            "interval": interval,
            "ip": ip,
            "timeline": timeline_sorted
        }

    except Exception as e:
        return {"error": f"IP timeline failed: {str(e)}"}

from fastapi import Query

@app.get("/api/anomaly_correlation_matrix")
def anomaly_correlation_matrix(
    interval: str = Query("minute", regex="^(minute|hour|day)$"),
    minutes: int = Query(180, gt=0, le=1440)
):
    """Pearson correlation between per-bucket anomaly counts of event types.

    Only the flow, http, dns and tls types are considered. Types without any
    anomaly in the window are left out of the matrix.

    Args:
        interval: Bucket size: ``minute``, ``hour`` or ``day``.
        minutes: Look-back window, in minutes, counted from now.

    Returns:
        Dict with ``types`` (column labels) and ``matrix`` (correlation
        values rounded to three decimals, undefined values replaced by 0);
        or a dict with an ``"error"`` message on failure.
    """
    try:
        redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True)
        start_time = datetime.now() - timedelta(minutes=minutes)

        # "min" / "h" / "D" are the current pandas spellings; the one-letter
        # "T" / "H" aliases are deprecated in recent pandas releases.
        freq = {"minute": "min", "hour": "h", "day": "D"}[interval]

        event_types = ["flow", "http", "dns", "tls"]
        series_data = {}

        for event_type in event_types:
            raw = redis_client.lrange(event_type, 0, -1)
            parsed = [json.loads(x) for x in raw if x.strip()]
            recent = [
                e for e in parsed
                if "timestamp" in e and datetime.fromisoformat(e["timestamp"]).replace(tzinfo=None) >= start_time
            ]

            model = joblib.load(MODEL_PATHS[event_type])
            anomalies, _ = ANALYZERS[event_type](recent, model)

            timestamps = [pd.to_datetime(e["timestamp"]) for e in anomalies if e.get("timestamp")]
            if not timestamps:
                continue

            df = pd.DataFrame({"ts": timestamps})
            # One count per time bucket, indexed by the bucket start.
            series_data[event_type] = df.groupby(df["ts"].dt.floor(freq)).size()

        # Merge by time, replace NaN with 0
        df_all = pd.DataFrame(series_data).fillna(0)

        # Calculate Pearson correlation matrix
        corr = df_all.corr(method="pearson").fillna(0)

        return {
            "types": corr.columns.tolist(),
            "matrix": corr.round(3).values.tolist()
        }

    except Exception as e:
        return {"error": f"Correlation matrix failed: {str(e)}"}

from pydantic import BaseModel

# Request body for POST /api/meta/check.
class MetaCheckRequest(BaseModel):
    # Passed as the second argument of extract_meta_features.
    # NOTE(review): presumably the look-back window in minutes — confirm.
    window_size: int = 60
    # Passed as the third argument of extract_meta_features.
    shift: int = 0

# Response body for POST /api/meta/check.
class MetaCheckResponse(BaseModel):
    # True when the meta model predicted -1 for the feature snapshot.
    is_anomaly: bool
    # The feature values that were scored.
    features: dict

@app.post("/api/meta/check", response_model=MetaCheckResponse)
def check_meta(request: MetaCheckRequest):
    """Score the current meta-feature snapshot and report whether it is anomalous.

    Args:
        request: Window size and shift forwarded to ``extract_meta_features``.

    Returns:
        Dict with ``is_anomaly`` and the scored ``features``. ``is_anomaly``
        is ``False`` when there is no snapshot, when the model file is
        missing, or when any step fails.
    """
    try:
        r = redis.Redis(host="localhost", port=6379, decode_responses=True)
        features = extract_meta_features(r, request.window_size, request.shift)

        # The extractor's result is consumed as a DataFrame elsewhere in this
        # module, while this endpoint needs both a frame (for the model) and
        # a plain dict (for the response). Accept either shape.
        # NOTE(review): confirm the actual return type of
        # extract_meta_features and drop the unused branch.
        if isinstance(features, pd.DataFrame):
            if features.empty:
                return {"is_anomaly": False, "features": {}}
            df = features
            features = df.to_dict(orient="records")[0]
        else:
            df = pd.DataFrame([features])

        if not os.path.exists("meta_model.pkl"):
            return {"is_anomaly": False, "features": features}

        model = load_model("meta_model.pkl")
        prediction = model.predict(df)[0]
        # Cast to a built-in bool; the comparison may yield a NumPy boolean.
        is_anomaly = bool(prediction == -1)

        return {"is_anomaly": is_anomaly, "features": features}

    except Exception as e:
        return {"is_anomaly": False, "features": {}, "error": str(e)}

from integrated_model import extract_meta_features, load_model

from integrated_model import detect_multi_anomalies

@app.get("/api/multianomaly_timeline")
def multianomaly_timeline(interval: str = "minute", minutes: int = 120):
    """Return the multi-anomaly timeline computed by ``detect_multi_anomalies``.

    Args:
        interval: Bucket size: ``minute``, ``hour`` or ``day``.
        minutes: Look-back window forwarded to ``detect_multi_anomalies``.

    Returns:
        Dict echoing ``interval`` and ``minutes`` together with the computed
        ``timeline``; or a dict with an ``"error"`` message when the interval
        is invalid or the computation fails.
    """
    try:
        allowed = ("minute", "hour", "day")
        if interval not in allowed:
            return {"error": "Invalid interval. Use minute, hour, or day."}

        conn = redis.Redis(host="localhost", port=6379, decode_responses=True)
        response = {
            "interval": interval,
            "minutes": minutes,
            "timeline": detect_multi_anomalies(conn, interval=interval, minutes=minutes),
        }
        return response
    except Exception as e:
        return {"error": f"Multianomaly timeline failed: {str(e)}"}

# Multi-anomaly log and escalation endpoints

from fastapi import Query

@app.get("/api/multianomalies_log")
def get_multianomalies_log(limit: int = 100, only_unescalated: bool = False):
    """Return entries from the head of the ``multianomalies`` Redis list.

    Args:
        limit: Maximum number of list entries to read. A value of zero or
            less yields an empty result.
        only_unescalated: When true, entries whose ``escalated`` field is
            truthy are left out.

    Returns:
        Dict with ``count`` and ``items``; or a dict with an ``"error"``
        message on failure.
    """
    try:
        if limit <= 0:
            # LRANGE's stop index is inclusive and negative values count from
            # the tail, so `limit - 1` would otherwise select the whole list.
            return {"count": 0, "items": []}

        redis_conn = redis.Redis(host="localhost", port=6379, decode_responses=True)
        raw = redis_conn.lrange("multianomalies", 0, limit - 1)

        results = []
        for row in raw:
            try:
                obj = json.loads(row)
                if only_unescalated and obj.get("escalated"):
                    continue
            except (ValueError, AttributeError):
                # Skip rows that are not valid JSON or not JSON objects.
                continue
            results.append(obj)

        return {"count": len(results), "items": results}
    except Exception as e:
        return {"error": str(e)}

from fastapi import Body

@app.post("/api/escalate_multianomaly")
def escalate_multianomaly(time: str = Body(..., embed=True)):
    """Mark the multi-anomaly with the given time as escalated and notify by e-mail.

    The first entry of the ``multianomalies`` Redis list whose ``time`` field
    equals ``time`` is rewritten with ``escalated`` set to true, and an
    escalation e-mail is sent for it.

    Args:
        time: Value of the ``time`` field of the entry to escalate.

    Returns:
        ``{"status": "ok"}`` when an entry was updated, otherwise a dict with
        an ``"error"`` message.
    """
    try:
        redis_conn = redis.Redis(host="localhost", port=6379, decode_responses=True)
        entries = redis_conn.lrange("multianomalies", 0, -1)

        for i, raw in enumerate(entries):
            try:
                obj = json.loads(raw)
            except ValueError:
                continue  # skip entries that are not valid JSON
            if not isinstance(obj, dict) or obj.get("time") != time:
                continue

            obj["escalated"] = True
            # NOTE(review): the index comes from the earlier LRANGE snapshot;
            # if the list is modified in between, a different entry could be
            # overwritten — confirm how the list is written.
            redis_conn.lset("multianomalies", i, json.dumps(obj))
            # This is the handler that serves the route (the path is
            # registered a second time further down, but the earlier
            # registration is matched), so the notification is sent here.
            send_escalation_email(obj)
            return {"status": "ok"}

        return {"error": "Multi-anomaly with such time not found"}
    except Exception as e:
        return {"error": str(e)}

import smtplib
from email.message import EmailMessage
from config_email import EMAIL_HOST, EMAIL_PORT, EMAIL_USERNAME, EMAIL_PASSWORD, EMAIL_FROM, EMAIL_TO

def send_escalation_email(anomaly):
    """Send an escalation notice for one multi-anomaly by e-mail.

    Delivery is best-effort: any failure (missing field, SMTP error) is
    printed and not raised.

    Args:
        anomaly: Mapping with ``time``, ``types`` (iterable of strings) and
            ``level`` entries.
    """
    try:
        # EMAIL_TO may be configured as one address or as a list of them;
        # joining a bare string would split it into single characters.
        recipients = [EMAIL_TO] if isinstance(EMAIL_TO, str) else list(EMAIL_TO)
        stream_types = ", ".join(anomaly['types'])

        msg = EmailMessage()
        msg["Subject"] = f"🚨 Multianomaly escalated: {anomaly['time']}"
        msg["From"] = EMAIL_FROM
        msg["To"] = ", ".join(recipients)

        body = f"""🔥 Multianomaly escalated

Time: {anomaly['time']}
Stream types: {stream_types}
Level: {anomaly['level']}

Escalated: True
"""

        msg.set_content(body)

        # STARTTLS is issued before login so the credentials travel over the
        # upgraded connection.
        with smtplib.SMTP(EMAIL_HOST, EMAIL_PORT) as server:
            server.starttls()
            server.login(EMAIL_USERNAME, EMAIL_PASSWORD)
            server.send_message(msg)

        print(f"[+] Letter was sent to: {EMAIL_TO}")
    except Exception as e:
        print(f"[!] Error sending letter: {e}")
#from yourmodule import send_escalation_email  

# NOTE(review): this path is also registered by an earlier handler of the
# same name in this file. Routes are normally matched in registration order,
# so this handler may never be invoked — confirm and consolidate the two.
@app.post("/api/escalate_multianomaly")
def escalate_multianomaly(time: str = Body(..., embed=True)):
    """Mark the multi-anomaly with the given time as escalated and send an e-mail.

    Args:
        time: Value of the ``time`` field of the entry to escalate.

    Returns:
        ``{"status": "ok"}`` when an entry was updated, otherwise a dict with
        an ``"error"`` message.
    """
    try:
        conn = redis.Redis(host="localhost", port=6379, decode_responses=True)
        found = False

        for position, payload in enumerate(conn.lrange("multianomalies", 0, -1)):
            try:
                record = json.loads(payload)
                if record.get("time") != time:
                    continue
                record["escalated"] = True
                conn.lset("multianomalies", position, json.dumps(record))
                send_escalation_email(record)
                found = True
                break
            except Exception:
                # Best effort: any failure on one entry moves on to the next.
                continue

        if found:
            return {"status": "ok"}
        return {"error": "Multi-anomaly with such time not found"}
    except Exception as e:
        return {"error": str(e)}
