#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Script de comparaison de scénarios
Fusionne plusieurs simulations et génère comparaisons CSV, MD et SVG
"""

import sys
import os
import csv
import glob
import xml.etree.ElementTree as ET
from typing import List, Dict, Tuple


def trouver_scenarios(pattern: str, dossier_base: str = "resultats") -> List[Tuple[str, str]]:
    """
    Trouve les dossiers de scénarios matchant le pattern
    
    Returns:
        List[(nom_scenario, chemin_dossier)]
    """
    chemin_pattern = os.path.join(dossier_base, pattern)
    dossiers = glob.glob(chemin_pattern)
    
    scenarios = []
    for dossier in sorted(dossiers):
        if os.path.isdir(dossier):
            nom_scenario = os.path.basename(dossier)
            fichier_csv = os.path.join(dossier, "donnees.csv")
            if os.path.exists(fichier_csv):
                scenarios.append((nom_scenario, dossier))
    
    return scenarios


def lire_donnees_scenario(nom_scenario: str, dossier: str) -> List[Dict]:
    """Lit le CSV d'un scénario et ajoute la colonne 'scenario'"""
    fichier_csv = os.path.join(dossier, "donnees.csv")
    donnees = []
    
    with open(fichier_csv, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        for row in reader:
            row['scenario'] = nom_scenario
            donnees.append(row)
    
    return donnees


def fusionner_donnees(scenarios: List[Tuple[str, str]]) -> List[Dict]:
    """Fusionne les données de tous les scénarios"""
    toutes_donnees = []
    
    for nom_scenario, dossier in scenarios:
        print(f"  Lecture: {nom_scenario}")
        donnees = lire_donnees_scenario(nom_scenario, dossier)
        toutes_donnees.extend(donnees)
    
    return toutes_donnees


def exporter_csv_comparaison(donnees: List[Dict], fichier_sortie: str):
    """Exporte le CSV de comparaison avec colonne 'scenario' en premier"""
    if not donnees:
        print("⚠ Aucune donnée à exporter")
        return
    
    # Ordre des colonnes: scenario en premier, puis le reste
    colonnes = ['scenario'] + [k for k in donnees[0].keys() if k != 'scenario']
    
    with open(fichier_sortie, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=colonnes)
        writer.writeheader()
        writer.writerows(donnees)
    
    print(f"✓ CSV comparaison exporté: {fichier_sortie}")


def calculer_synthese(scenarios: List[Tuple[str, str]]) -> List[Dict]:
    """Calcule les statistiques de synthèse pour chaque scénario"""
    synthese = []
    
    for nom_scenario, dossier in scenarios:
        donnees = lire_donnees_scenario(nom_scenario, dossier)
        
        if not donnees:
            continue
        
        derniere_annee = donnees[-1]
        
        # Trouver le différentiel maximum
        diff_max = max(float(d['differentiel_pct']) for d in donnees if d['differentiel_pct'])
        
        # Calculer moyenne du différentiel
        diffs = [float(d['differentiel_pct']) for d in donnees if d['differentiel_pct']]
        diff_moyen = sum(diffs) / len(diffs) if diffs else 0
        
        synthese.append({
            'scenario': nom_scenario,
            'duree_ans': int(derniere_annee['annee']),
            'pib_initial': float(donnees[0]['pib']),
            'pib_final': float(derniere_annee['pib']),
            'dette_pub_initiale': float(donnees[0]['dette_publique']),
            'dette_pub_finale': float(derniere_annee['dette_publique']),
            'dette_impl_initiale': float(donnees[0]['dette_implicite']),
            'dette_impl_finale': float(derniere_annee['dette_implicite']),
            'differentiel_max_pct': round(diff_max, 2),
            'differentiel_moyen_pct': round(diff_moyen, 2),
        })
    
    return synthese


def exporter_markdown_comparaison(synthese: List[Dict], nom_comparaison: str, fichier_sortie: str):
    """Génère le fichier markdown de comparaison"""
    with open(fichier_sortie, 'w', encoding='utf-8') as f:
        f.write(f"# Comparaison : {nom_comparaison}\n\n")
        
        f.write("## Tableau Synthèse\n\n")
        f.write("| Scénario | Durée (ans) | PIB final | Dette Pub. finale | Dette Impl. finale | Diff. max | Diff. moyen |\n")
        f.write("|----------|-------------|-----------|-------------------|--------------------|-----------|--------------|\n")
        
        for s in synthese:
            f.write(f"| {s['scenario']} | {s['duree_ans']} | "
                   f"{s['pib_final']:.0f} Mds | "
                   f"{s['dette_pub_finale']:.0f} Mds | "
                   f"{s['dette_impl_finale']:.0f} Mds | "
                   f"{s['differentiel_max_pct']:.1f}% | "
                   f"{s['differentiel_moyen_pct']:.1f}% |\n")
        
        f.write("\n---\n\n")
        f.write("## Classement par Durée\n\n")
        
        classes = sorted(synthese, key=lambda x: x['duree_ans'])
        for i, s in enumerate(classes, 1):
            f.write(f"{i}. **{s['scenario']}** : {s['duree_ans']} ans\n")
        
        f.write("\n---\n\n")
        f.write("## Analyse\n\n")
        
        plus_rapide = min(synthese, key=lambda x: x['duree_ans'])
        plus_long = max(synthese, key=lambda x: x['duree_ans'])
        
        f.write(f"**Scénario le plus rapide** : {plus_rapide['scenario']} ({plus_rapide['duree_ans']} ans)\n\n")
        f.write(f"**Scénario le plus long** : {plus_long['scenario']} ({plus_long['duree_ans']} ans)\n\n")
        
        ecart = plus_long['duree_ans'] - plus_rapide['duree_ans']
        f.write(f"**Écart** : {ecart} ans entre le scénario le plus rapide et le plus long\n\n")
        
        f.write("---\n\n")
        f.write("## Graphiques\n\n")
        f.write("Voir le fichier `comparaison.svg` pour les graphiques de comparaison.\n")
    
    print(f"✓ Markdown comparaison exporté: {fichier_sortie}")


def generer_svg_comparaison(donnees_par_scenario: Dict[str, List[Dict]], 
                            nom_comparaison: str, fichier_sortie: str):
    """Génère les graphiques SVG de comparaison"""
    
    # Couleurs et styles pour chaque scénario
    couleurs = ['#4A90E2', '#50C878', '#E85D75', '#9B59B6', '#F39C12', '#3498DB']
    styles = ['', '5,5', '10,5', '3,3', '8,2', '5,2']
    
    scenarios = list(donnees_par_scenario.keys())
    
    # Dimensions
    largeur_totale = 1200
    hauteur_totale = 1000
    marge = 80
    espacement = 100
    
    largeur_graph = (largeur_totale - 3 * marge) // 2
    hauteur_graph = (hauteur_totale - 4 * espacement) // 3
    
    # Créer SVG
    svg = ET.Element('svg', {
        'width': str(largeur_totale),
        'height': str(hauteur_totale),
        'xmlns': 'http://www.w3.org/2000/svg',
        'version': '1.1'
    })
    
    # Fond blanc
    ET.SubElement(svg, 'rect', {
        'width': str(largeur_totale),
        'height': str(hauteur_totale),
        'fill': 'white'
    })
    
    # Titre
    ET.SubElement(svg, 'text', {
        'x': str(largeur_totale / 2),
        'y': '30',
        'text-anchor': 'middle',
        'font-size': '18',
        'font-weight': 'bold',
        'fill': '#333'
    }).text = f'Comparaison : {nom_comparaison}'
    
    # Fonction helper pour dessiner un graphique
    def dessiner_graphique(x, y, titre, champ, unite):
        # Titre
        ET.SubElement(svg, 'text', {
            'x': str(x + largeur_graph / 2),
            'y': str(y - 10),
            'text-anchor': 'middle',
            'font-size': '14',
            'font-weight': 'bold',
            'fill': '#333'
        }).text = titre
        
        # Grille
        for i in range(6):
            y_pos = y + (hauteur_graph * i / 5)
            ET.SubElement(svg, 'line', {
                'x1': str(x),
                'y1': str(y_pos),
                'x2': str(x + largeur_graph),
                'y2': str(y_pos),
                'stroke': '#e0e0e0',
                'stroke-width': '1'
            })
        
        # Trouver min/max global
        val_max = 0
        for donnees in donnees_par_scenario.values():
            vals = [float(d[champ]) for d in donnees if champ in d and d[champ]]
            if vals:
                val_max = max(val_max, max(vals))
        
        if val_max == 0:
            return
        
        # Dessiner chaque scénario
        for idx, (nom_sc, donnees) in enumerate(donnees_par_scenario.items()):
            couleur = couleurs[idx % len(couleurs)]
            style = styles[idx % len(styles)]
            
            valeurs = [float(d[champ]) for d in donnees if champ in d and d[champ]]
            if not valeurs:
                continue
            
            # Normaliser
            valeurs_norm = [v / val_max for v in valeurs]
            
            # Points
            nb_points = len(valeurs)
            points_x = [x + (largeur_graph * i / (nb_points - 1)) for i in range(nb_points)]
            points_y = [y + (hauteur_graph * (1 - v)) for v in valeurs_norm]
            
            # Chemin
            points = [f"{px},{py}" for px, py in zip(points_x, points_y)]
            chemin = "M " + " L ".join(points)
            
            attrs = {
                'd': chemin,
                'fill': 'none',
                'stroke': couleur,
                'stroke-width': '2'
            }
            if style:
                attrs['stroke-dasharray'] = style
            
            ET.SubElement(svg, 'path', attrs)
        
        # Axes
        ET.SubElement(svg, 'line', {
            'x1': str(x), 'y1': str(y),
            'x2': str(x), 'y2': str(y + hauteur_graph),
            'stroke': '#333', 'stroke-width': '2'
        })
        ET.SubElement(svg, 'line', {
            'x1': str(x), 'y1': str(y + hauteur_graph),
            'x2': str(x + largeur_graph), 'y2': str(y + hauteur_graph),
            'stroke': '#333', 'stroke-width': '2'
        })
        
        # Labels Y
        ET.SubElement(svg, 'text', {
            'x': str(x - 10), 'y': str(y + 5),
            'text-anchor': 'end', 'font-size': '10', 'fill': '#666'
        }).text = f"{val_max:.0f}"
        ET.SubElement(svg, 'text', {
            'x': str(x - 10), 'y': str(y + hauteur_graph + 5),
            'text-anchor': 'end', 'font-size': '10', 'fill': '#666'
        }).text = "0"
        
        # Labels X (décennies)
        duree_max = max(len(d) for d in donnees_par_scenario.values())
        for dec in range(0, duree_max + 1, 10):
            if dec < duree_max:
                x_pos = x + (largeur_graph * dec / (duree_max - 1)) if duree_max > 1 else x
                ET.SubElement(svg, 'line', {
                    'x1': str(x_pos), 'y1': str(y + hauteur_graph),
                    'x2': str(x_pos), 'y2': str(y + hauteur_graph + 5),
                    'stroke': '#333', 'stroke-width': '1'
                })
                ET.SubElement(svg, 'text', {
                    'x': str(x_pos), 'y': str(y + hauteur_graph + 18),
                    'text-anchor': 'middle', 'font-size': '9', 'fill': '#666'
                }).text = f"{dec}"
    
    # 6 graphiques
    dessiner_graphique(marge, espacement, 'PIB (Mds €)', 'pib', 'Mds €')
    dessiner_graphique(marge + largeur_graph + marge, espacement, 
                      'Différentiel (% PIB)', 'differentiel_pct', '%')
    
    dessiner_graphique(marge, espacement * 2 + hauteur_graph,
                      'Dette Publique (% PIB)', 'dette_publique_totale_pct', '%')
    dessiner_graphique(marge + largeur_graph + marge, espacement * 2 + hauteur_graph,
                      'Dette Implicite (% PIB)', 'dette_implicite_pct', '%')
    
    dessiner_graphique(marge, espacement * 3 + hauteur_graph * 2,
                      'Intérêts (Mds €)', 'interets_totaux', 'Mds €')
    dessiner_graphique(marge + largeur_graph + marge, espacement * 3 + hauteur_graph * 2,
                      'Flux Pensions (Mds €)', 'flux_pensions', 'Mds €')
    
    # Légende globale en bas
    leg_x = marge
    leg_y = hauteur_totale - 50
    
    ET.SubElement(svg, 'text', {
        'x': str(leg_x),
        'y': str(leg_y - 10),
        'font-size': '12',
        'font-weight': 'bold',
        'fill': '#333'
    }).text = 'Légende:'
    
    for idx, nom_sc in enumerate(scenarios):
        couleur = couleurs[idx % len(couleurs)]
        style = styles[idx % len(styles)]
        
        x_leg = leg_x + (idx % 3) * 300
        y_leg = leg_y + (idx // 3) * 20
        
        attrs = {
            'x1': str(x_leg),
            'y1': str(y_leg),
            'x2': str(x_leg + 40),
            'y2': str(y_leg),
            'stroke': couleur,
            'stroke-width': '2'
        }
        if style:
            attrs['stroke-dasharray'] = style
        
        ET.SubElement(svg, 'line', attrs)
        ET.SubElement(svg, 'text', {
            'x': str(x_leg + 45),
            'y': str(y_leg + 4),
            'font-size': '10',
            'fill': '#666'
        }).text = nom_sc
    
    # Écrire le fichier
    tree = ET.ElementTree(svg)
    ET.indent(tree, space="  ")
    tree.write(fichier_sortie, encoding='utf-8', xml_declaration=True)
    
    print(f"✓ SVG comparaison exporté: {fichier_sortie}")


def main():
    """Fonction principale"""
    if len(sys.argv) < 3:
        print("Usage: python generer_comparaison.py <nom_comparaison> <pattern>")
        print()
        print("Exemples:")
        print("  python generer_comparaison.py 'Scénarios Belgique' 'belgique_*'")
        print("  python generer_comparaison.py 'Pays Rapides' 'pologne_*,paysbas_*'")
        print()
        sys.exit(1)
    
    nom_comparaison = sys.argv[1]
    patterns = sys.argv[2].split(',')
    
    print("=" * 60)
    print(f"GÉNÉRATION COMPARAISON : {nom_comparaison}")
    print("=" * 60)
    print()
    
    # Trouver tous les scénarios
    scenarios = []
    for pattern in patterns:
        pattern = pattern.strip()
        print(f"Recherche pattern: {pattern}")
        scenarios.extend(trouver_scenarios(pattern))
    
    if not scenarios:
        print("✗ Aucun scénario trouvé !")
        sys.exit(1)
    
    print()
    print(f"✓ {len(scenarios)} scénario(s) trouvé(s):")
    for nom, _ in scenarios:
        print(f"  - {nom}")
    print()
    
    # Créer dossier de sortie
    nom_dossier = nom_comparaison.lower().replace(' ', '_').replace('é', 'e').replace('è', 'e')
    dossier_sortie = f"resultats/comparaison_{nom_dossier}"
    os.makedirs(dossier_sortie, exist_ok=True)
    print(f"Dossier de sortie: {dossier_sortie}")
    print()
    
    # Fusionner les données
    print("Fusion des données...")
    donnees_fusionnees = fusionner_donnees(scenarios)
    print(f"✓ {len(donnees_fusionnees)} lignes fusionnées")
    print()
    
    # Exporter CSV
    print("Génération CSV...")
    exporter_csv_comparaison(donnees_fusionnees, f"{dossier_sortie}/comparaison.csv")
    print()
    
    # Calculer synthèse
    print("Calcul synthèse...")
    synthese = calculer_synthese(scenarios)
    print()
    
    # Exporter Markdown
    print("Génération Markdown...")
    exporter_markdown_comparaison(synthese, nom_comparaison, f"{dossier_sortie}/comparaison.md")
    print()
    
    # Préparer données pour SVG (par scénario)
    donnees_par_scenario = {}
    for nom_scenario, dossier in scenarios:
        donnees_par_scenario[nom_scenario] = lire_donnees_scenario(nom_scenario, dossier)
    
    # Générer SVG
    print("Génération SVG...")
    generer_svg_comparaison(donnees_par_scenario, nom_comparaison, f"{dossier_sortie}/comparaison.svg")
    print()
    
    print("=" * 60)
    print("COMPARAISON TERMINÉE")
    print("=" * 60)
    print()
    print(f"Fichiers générés dans: {dossier_sortie}/")
    print("  - comparaison.csv")
    print("  - comparaison.md")
    print("  - comparaison.svg")


if __name__ == "__main__":
    main()
