# tls_ai_module.py
import os
import json
import pandas as pd
import joblib
import redis
import argparse
from sklearn.ensemble import IsolationForest
from c_features import extract_features

# Default path used by save_model() / load_model() for the fitted model.
MODEL_PATH = "tls_model.pkl"
# Text file holding the training feature column names, one per line.
FEATURES_FILE = "tls_model_features.txt"
# Event type handed to extract_features(); also the Redis list key read in main().
EVENT_TYPE = "tls"


# ---------------- Sanitize ----------------
def _sanitize_tls_df(df: pd.DataFrame) -> pd.DataFrame:
    """
    Drop non-scalar columns (dict/list/tuple/set),
    coerce object columns to numeric (drop if cannot be coerced),
    and replace NaNs/inf with 0.
    """
    # 1) drop columns with non-scalar values
    bad_cols = []
    for c in df.columns:
        s = df[c]
        sample = s.head(1000)
        if sample.map(lambda v: isinstance(v, (dict, list, tuple, set))).any():
            bad_cols.append(c)
    if bad_cols:
        df = df.drop(columns=bad_cols)

    # 2) coerce object columns to numeric or drop
    for c in list(df.select_dtypes(include=["object"]).columns):
        numeric = pd.to_numeric(df[c], errors="coerce")
        if numeric.notna().any():
            df[c] = numeric.fillna(0)
        else:
            df = df.drop(columns=[c])

    # 3) final safety
    df = df.replace([float("inf"), float("-inf")], 0).fillna(0)
    # any leftover object → numeric
    for c in df.columns:
        if df[c].dtype == object:
            df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0)
    return df


# ---------------- IO ----------------
def save_model(model, path=MODEL_PATH):
    """Serialize *model* to *path* with joblib and print where it went."""
    joblib.dump(value=model, filename=path)
    print(f"[+] Model saved in {path}")

def load_model(path=MODEL_PATH):
    """
    Load the joblib-serialized model stored at *path* and return it.

    NOTE: joblib files are pickle-based, so only load files from a
    trusted location.
    """
    restored = joblib.load(path)
    print(f"[+] Model loaded from {path}")
    return restored

def save_logs_to_file(events, filename="tls_log.json"):
    """
    Write *events* to *filename*, one JSON document per line.

    An existing file is overwritten; an empty *events* yields an empty file.
    """
    with open(filename, "w") as sink:
        sink.writelines(json.dumps(event) + "\n" for event in events)
    print(f"[+] Events saved in {filename}")

def load_logs_from_file(filename="tls_log.json"):
    """
    Read events from a JSON Lines file (one JSON document per line).

    Blank or whitespace-only lines, such as a trailing empty line, are
    skipped instead of aborting the whole load. The file is read as UTF-8.

    Args:
        filename: path of the file to read.

    Returns:
        list of decoded JSON values, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(filename, "r", encoding="utf-8") as f:
        data = [json.loads(record) for record in map(str.strip, f) if record]
    print(f"[+] Loaded {len(data)} events from {filename}")
    return data


# ---------------- Inference ----------------
def analyze_and_tag(events, model):
    """
    Analyze a list of Suricata 'tls' events with a fitted IsolationForest model.

    Features come from the unified C++ extractor and are aligned first to the
    locally saved schema (FEATURES_FILE, if present) and then to the model's
    own training schema (``feature_names_in_``, if present).

    Args:
        events: list of event dicts. Each dict is tagged in place with
            ``'anomaly': True/False``.
        model: fitted estimator exposing ``predict``; a prediction of -1
            marks the event as anomalous.

    Returns:
        (anomalies, normals): two lists holding the same event objects.
        If no features could be extracted, every event is tagged as not
        anomalous and ``([], events)`` is returned.
    """
    # 1) Load saved local schema if available (blank lines are not columns)
    required_cols = []
    if os.path.exists(FEATURES_FILE):
        with open(FEATURES_FILE) as f:
            required_cols = [name for name in f.read().splitlines() if name.strip()]

    # 2) Extract features via unified C++ extractor (respect local schema if present)
    features_list = extract_features(EVENT_TYPE, events, required_cols or None)

    if not features_list:
        print("[!] Empty list of features from C++")
        # Nothing could be scored; still honour the tagging contract.
        for event in events:
            event["anomaly"] = False
        return [], events

    # 3) Build DataFrame and sanitize
    df = _sanitize_tls_df(pd.DataFrame(features_list))

    # 4) Basic alignment by local schema: reindex adds missing columns as 0,
    #    drops unknown ones and fixes the order in a single step.
    if required_cols:
        df = df.reindex(columns=required_cols, fill_value=0)

    # 5) Strict alignment with the model's training schema
    expected = getattr(model, "feature_names_in_", None)
    if expected is not None:
        df = df.reindex(columns=list(expected), fill_value=0)

    # 6) Safety: the sanitizer returned numeric columns and reindex only adds
    #    zero-filled ones, so a final NaN fill is the only guard needed.
    df = df.fillna(0)

    # 7) Predict and tag anomalies
    predictions = model.predict(df)

    anomalies, normals = [], []
    for i, event in enumerate(events):
        is_anomaly = bool(predictions[i] == -1)
        event["anomaly"] = is_anomaly
        (anomalies if is_anomaly else normals).append(event)

    print(f"[+] Total events: {len(events)}")
    print(f"[+] Anomalies: {len(anomalies)} | Normal: {len(normals)}")

    return anomalies, normals


# ---------------- CLI / Training ----------------
def main(args=None):
    """
    Command-line entry point: load TLS events, train or load the model and
    optionally report anomalies.

    Args:
        args: pre-parsed namespace with ``file``, ``retrain`` and ``analyze``
            attributes. When None, the arguments are parsed from the
            command line.

    Raises:
        ValueError: if training is required but no usable features could be
            extracted from the loaded events.
    """
    if args is None:
        parser = argparse.ArgumentParser(description="AI-analysis TLS-events from Suricata")
        parser.add_argument("--retrain", action="store_true")
        parser.add_argument("--analyze", action="store_true")
        parser.add_argument("--file", type=str)
        # Accepted for CLI compatibility; not read anywhere in this function.
        parser.add_argument("--live", action="store_true")
        args = parser.parse_args()

    # Load events
    if args.file:
        parsed = load_logs_from_file(args.file)
    else:
        r = redis.Redis(host="localhost", port=6379, decode_responses=True)
        raw = r.lrange(EVENT_TYPE, 0, -1)
        parsed = [json.loads(x) for x in raw]
        print(f"[+] Loaded {len(parsed)} events from Redis")

    # Train or load
    if args.retrain or not os.path.exists(MODEL_PATH):
        # Features are extracted only here: analyze_and_tag() runs its own
        # extraction, so the load-only path must not pay for it twice.
        features_list = extract_features(EVENT_TYPE, parsed, None)
        df = _sanitize_tls_df(pd.DataFrame(features_list))
        if df.empty:
            raise ValueError("No TLS features available to train the model")

        model = IsolationForest(n_estimators=100, contamination=0.05, random_state=42)
        model.fit(df)
        save_model(model)
        # Persist training schema AFTER sanitize
        with open(FEATURES_FILE, "w") as f:
            f.write("\n".join(map(str, df.columns)))
        print(f"[+] The list of features is saved in {FEATURES_FILE}")
    else:
        model = load_model()

    # Analyze
    if args.analyze:
        anomalies, _normals = analyze_and_tag(parsed, model)
        save_logs_to_file(anomalies, "tls_anomalies.json")

        df_anomalies = pd.DataFrame([
            {"src_ip": e.get("src_ip"), "dest_ip": e.get("dest_ip")}
            for e in anomalies
        ])

        # Same report for both endpoints: heading, then top-10 value counts.
        reports = (
            ("\n=== Top 10 Sources of Anomalous TLS Traffic:", "src_ip"),
            ("\n=== Top 10 recipients of anomalous TLS traffic:", "dest_ip"),
        )
        for heading, column in reports:
            print(heading)
            if not df_anomalies.empty and column in df_anomalies:
                print(df_anomalies[column].value_counts().head(10))
            else:
                print("No data")


# Run the CLI only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
