# dns_ai_module.py

import os
import json
import argparse
import redis
import joblib
import pandas as pd
from sklearn.ensemble import IsolationForest

from c_features import extract_features, tag_anomalies

# Persisted IsolationForest model (written with joblib) and the list of
# feature columns it was trained on (one name per line).
MODEL_PATH = "dns_model.pkl"
FEATURE_LIST_PATH = "dns_model_features.txt"
# Event type passed to the feature extractor; also the Redis list name read.
EVENT_TYPE = "dns"


# ---------------- Sanitize ----------------
def _sanitize_dns_df(df: pd.DataFrame) -> pd.DataFrame:
    """
    Drop non-scalar columns (dict/list/tuple/set),
    coerce object columns to numeric (drop if cannot be coerced),
    and replace NaNs/inf with 0.
    """
    # 1) drop columns with non-scalar values
    bad_cols = []
    for c in df.columns:
        s = df[c]
        sample = s.head(1000)
        if sample.map(lambda v: isinstance(v, (dict, list, tuple, set))).any():
            bad_cols.append(c)
    if bad_cols:
        df = df.drop(columns=bad_cols)

    # 2) coerce object columns to numeric or drop
    for c in list(df.select_dtypes(include=["object"]).columns):
        numeric = pd.to_numeric(df[c], errors="coerce")
        if numeric.notna().any():
            df[c] = numeric.fillna(0)
        else:
            df = df.drop(columns=[c])

    # 3) final safety
    df = df.replace([float("inf"), float("-inf")], 0).fillna(0)
    # any leftover object → numeric
    for c in df.columns:
        if df[c].dtype == object:
            df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0)
    return df


# ---------------- IO utils ----------------
def load_logs_from_file(filename='dns_log.json'):
    """
    Load newline-delimited JSON events from *filename*.

    Blank or whitespace-only lines (e.g. an extra empty line at the end of the
    file) are skipped instead of making ``json.loads`` fail.

    Args:
        filename: Path to a JSON-lines file with one JSON value per line.

    Returns:
        List of decoded values, in file order.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(filename, 'r', encoding='utf-8') as f:
        data = [json.loads(line) for line in f if line.strip()]
    print(f"[+] Loaded {len(data)} events from {filename}")
    return data

def save_logs_to_file(events, filename='dns_log.json'):
    """
    Overwrite *filename* with *events* serialized as JSON lines.

    Each event becomes one ``json.dumps`` line terminated by a newline.
    """
    with open(filename, 'w') as out:
        out.writelines(json.dumps(event) + "\n" for event in events)
    print(f"[+] Events saved in {filename}")

def extract_required_columns(path=None):
    """
    Read the training feature schema (one column name per line).

    Args:
        path: Schema file to read. Defaults to ``FEATURE_LIST_PATH``. It is
            optional so existing zero-argument callers keep working.

    Returns:
        List of column names in file order, with blank lines skipped, or
        ``None`` when the schema file does not exist.
    """
    if path is None:
        path = FEATURE_LIST_PATH
    if not os.path.exists(path):
        return None
    with open(path) as f:
        # A stray blank line (e.g. left after hand-editing the file) must not
        # turn into a phantom "" feature column that gets filled with zeros.
        return [name for name in f.read().splitlines() if name.strip()]


# ---------------- Inference ----------------
def analyze_and_tag(events, model):
    """
    Score *events* with *model* and split them into anomalies and normals.

    Features come from the C++ ``extract_features`` helper and are sanitized
    to numeric columns. They are then aligned first to the on-disk training
    schema (if present) and then to the model's ``feature_names_in_`` (if the
    model exposes it). Missing columns are filled with 0; extra columns are
    dropped.

    Args:
        events: List of decoded events.
        model: Fitted estimator with a ``predict`` method.

    Returns:
        ``(anomalies, normals)``: the events returned by ``tag_anomalies``,
        split on whether their ``"anomaly"`` key is truthy. If the extractor
        returns no features, returns ``([], events)``.
    """
    # 1) Columns list from training (if exists)
    required_cols = extract_required_columns()

    # 2) Extract features via C++ (respect local schema if present, else None)
    features = extract_features(EVENT_TYPE, events, required_cols)

    if not features:
        print("[!] Empty list of features from C++")
        return [], events

    # 3) Build DataFrame and sanitize
    df = _sanitize_dns_df(pd.DataFrame(features))

    # 4) Align by local schema (if present). reindex adds all missing columns
    #    in one step (filled with 0). Inserting them one by one fragments the
    #    frame on wide schemas and can trigger pandas' PerformanceWarning.
    if required_cols:
        df = df.reindex(columns=required_cols, fill_value=0)

    # 5) Strict alignment to model schema (feature_names_in_) if available.
    #    No separate required_cols fallback is needed: step 4 already left
    #    exactly those columns, in that order.
    expected = getattr(model, "feature_names_in_", None)
    if expected is not None:
        df = df.reindex(columns=list(expected), fill_value=0)

    # 6) Safety
    df = df.fillna(0)
    for c in df.columns:
        if df[c].dtype == object:
            df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0)

    # NOTE(review): tag_anomalies gets one prediction per feature row and
    # presumably expects one per event. Warn loudly if the counts diverge
    # (TODO confirm against c_features).
    if len(df) != len(events):
        print(f"[!] Feature rows ({len(df)}) do not match events ({len(events)})")

    # 7) Prediction + tagging
    preds = model.predict(df)
    tagged = tag_anomalies(events, preds.tolist())

    anomalies = [e for e in tagged if e.get("anomaly")]
    normals = [e for e in tagged if not e.get("anomaly")]

    print(f"[+] Total events: {len(events)}")
    print(f"[+] Anomalies: {len(anomalies)} | Normal: {len(normals)}")
    return anomalies, normals


# ---------------- CLI / Training ----------------
def _load_events(args):
    """Return events from ``args.file`` (JSON lines) or, if unset, from the local Redis list ``EVENT_TYPE``."""
    if args.file:
        return load_logs_from_file(args.file)
    r = redis.Redis(host='localhost', port=6379, decode_responses=True)
    raw = r.lrange(EVENT_TYPE, 0, -1)
    events = [json.loads(x) for x in raw]
    print(f"[+] Loaded {len(events)} events from Redis")
    return events


def _train_model(events):
    """
    Fit a new IsolationForest on *events* and persist it.

    Saves the model to ``MODEL_PATH`` and the post-sanitize feature schema
    to ``FEATURE_LIST_PATH``, then returns the fitted model.

    Raises:
        ValueError: if no usable feature rows/columns remain after sanitizing.
    """
    # Extract all features (no forced schema) and sanitize
    features = extract_features(EVENT_TYPE, events, None)
    df = _sanitize_dns_df(pd.DataFrame(features))
    if df.empty:
        # Same exception type fit() would raise, but with an actionable message.
        raise ValueError(
            f"No usable {EVENT_TYPE} features to train on "
            f"({len(events)} events loaded)"
        )

    model = IsolationForest(n_estimators=100, contamination=0.05, random_state=42)
    model.fit(df)
    joblib.dump(model, MODEL_PATH)

    # Persist training schema AFTER sanitize
    with open(FEATURE_LIST_PATH, "w") as f:
        f.write("\n".join(df.columns))
    print(f"[+] The model is trained and the features are saved in {FEATURE_LIST_PATH}")
    return model


def _print_top(df_anom, column, title):
    """Print *title*, then the 10 most frequent values of *column*, or "No data"."""
    print(title)
    if not df_anom.empty and column in df_anom:
        print(df_anom[column].value_counts().head(10))
    else:
        print("No data")


def main(args=None):
    """
    CLI entry point: load events, train or load the model, optionally analyze.

    Args:
        args: Pre-parsed namespace with ``retrain``, ``analyze``, ``file`` and
            ``live`` attributes. When None, they are parsed from the command
            line. NOTE(review): ``--live`` is parsed but not used here.
    """
    if args is None:
        parser = argparse.ArgumentParser(description='AI DNS Suricata Events Analyzer')
        parser.add_argument('--retrain', action='store_true')
        parser.add_argument('--analyze', action='store_true')
        parser.add_argument('--file', type=str)
        parser.add_argument('--live', action='store_true')
        args = parser.parse_args()

    events = _load_events(args)

    # Training-time features are only needed to fit a new model. When an
    # existing model is reused, the extraction pass is skipped entirely.
    if args.retrain or not os.path.exists(MODEL_PATH):
        model = _train_model(events)
    else:
        model = joblib.load(MODEL_PATH)

    # Inference / analysis
    if args.analyze:
        anomalies, normals = analyze_and_tag(events, model)
        save_logs_to_file(anomalies, 'dns_anomalies.json')

        df_anom = pd.DataFrame([
            {"src_ip": e.get("src_ip"), "dest_ip": e.get("dest_ip")}
            for e in anomalies
        ])

        _print_top(df_anom, "src_ip", "\n=== Top 10 Sources of Anomalous DNS Traffic:")
        _print_top(df_anom, "dest_ip", "\n=== Top 10 Recipients of Anomalous DNS Traffic:")


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
