from __future__ import annotations

import json
import math
from functools import lru_cache
from pathlib import Path

import pandas as pd
from PIL import Image, ImageDraw, ImageFont


# Inputs are read from BASE; every manuscript artifact is written under OUT,
# with figures in their own sub-directory.
BASE = Path("outputs/resting_youth_ai_outlook")
OUT = Path("outputs/tfsc_manuscript")
FIG = OUT / "figures"
for _directory in (OUT, FIG):
    _directory.mkdir(parents=True, exist_ok=True)


def load(name: str) -> pd.DataFrame:
    """Read the table stored as ``<BASE>/<name>.json`` into a DataFrame."""
    source = BASE.joinpath(name + ".json")
    return pd.read_json(source)


@lru_cache(maxsize=64)
def font(size: int, bold: bool = False) -> ImageFont.FreeTypeFont:
    """Return a font of ``size`` pixels, preferring a bold face when ``bold`` is set.

    Candidate files are tried in order; unavailable or unreadable files are skipped.
    When ``bold`` is requested, dedicated bold files are tried before the regular
    list, so bold still degrades to a regular face rather than failing. If no
    TrueType file loads, Pillow's built-in default font is returned (sized when the
    installed Pillow supports ``load_default(size=...)``).

    Results are memoized per ``(size, bold)`` because ``draw_text`` calls this for
    every single label.
    """
    regular = [
        "/System/Library/Fonts/Supplemental/Arial Unicode.ttf",
        "/System/Library/Fonts/Supplemental/Arial.ttf",
        "/Library/Fonts/Arial Unicode.ttf",
        "/System/Library/Fonts/Helvetica.ttc",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
    ]
    bold_faces = [
        "/System/Library/Fonts/Supplemental/Arial Bold.ttf",
        "/Library/Fonts/Arial Bold.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
    ]
    candidates = bold_faces + regular if bold else regular
    for path in candidates:
        # Single-face .ttf files only have face 0; for a .ttc collection face 1 is
        # used as the bold variant (assumed to be Helvetica Bold -- confirm on target OS).
        index = 1 if bold and path.endswith(".ttc") else 0
        try:
            return ImageFont.truetype(path, size=size, index=index)
        except (OSError, ImportError):
            # OSError: file missing/unreadable; ImportError: Pillow built without FreeType.
            continue
    try:
        return ImageFont.load_default(size=size)
    except TypeError:
        # Older Pillow: load_default() takes no size argument.
        return ImageFont.load_default()


def draw_text(draw: ImageDraw.ImageDraw, xy, text, size=22, fill="#111827", bold=False, anchor=None):
    """Draw ``text`` (coerced with ``str``) at ``xy`` using the font chosen by ``font(size, bold)``."""
    typeface = font(size, bold)
    draw.text(xy, str(text), font=typeface, fill=fill, anchor=anchor)


def fmt(value, digits=1):
    """Format a number with thousands separators and ``digits`` decimal places.

    Missing values -- ``None``, ``pd.NA`` and NaN of any numeric type (Python,
    NumPy or ``Decimal``) -- format as an empty string so that table cells stay
    blank instead of showing ``nan``.

    >>> fmt(1234.567)
    '1,234.6'
    >>> fmt(float("nan"))
    ''
    """
    if value is None or value is pd.NA:
        return ""
    number = float(value)
    # Check after conversion so np.float32 / Decimal NaNs (not float subclasses) are caught too.
    if math.isnan(number):
        return ""
    return f"{number:,.{digits}f}"


# Upstream aggregate tables, loaded one after another in this fixed order.
yearly, annual, occ, occ_year, sex, edu, trend = (
    load(table)
    for table in (
        "yearly",
        "annual_risk",
        "occ_risk_2025",
        "occ_year",
        "sex_occ_2025",
        "edu_occ_2025",
        "trend",
    )
)

# First row for survey year 2025 in each yearly table (IndexError if 2025 is absent).
latest = yearly[yearly["year"] == 2025].iloc[0]
latest_risk = annual[annual["year"] == 2025].iloc[0]

# English display labels for the Korean occupation-group names used in the data.
OCC_EN = {
    "관리자": "Managers",
    "전문가 및 관련 종사자": "Professionals",
    "사무 종사자": "Clerical",
    "서비스 종사자": "Service",
    "판매 종사자": "Sales",
    "농림어업 숙련 종사자": "Skilled agri.",
    "기능원 및 관련 기능 종사자": "Craft",
    "장치·기계 조작 및 조립 종사자": "Operators",
    "단순노무 종사자": "Elementary",
}

# English display labels for the Korean education-level names used in the data.
EDU_EN = {
    "중졸": "Middle school",
    "고졸": "High school",
    "전문대졸": "Junior college",
    "대졸": "University",
    "대학원 이상": "Graduate school",
}


def draw_line_chart() -> Path:
    """Render Figure 1: resting-youth counts and the high AI-pressure share by year.

    The two population series (left axis, thousand persons) come from ``yearly``;
    the high AI-pressure share (right axis, percent) comes from ``annual``. Each
    series is positioned by its own ``year`` column, so the two tables need not
    share row order or year coverage.

    Returns:
        Path of the written PNG (``FIG / "figure1_trends.png"``).
    """
    width, height = 1800, 1050
    img = Image.new("RGB", (width, height), "white")
    d = ImageDraw.Draw(img)
    margin = 120
    plot_w = width - margin * 2
    plot_h = 690
    top = 210
    left = margin

    pop_years = yearly["year"].astype(int).tolist()
    risk_years = annual["year"].astype(int).tolist()
    rest = yearly["resting_youth_thousand"].tolist()
    occ_resp = yearly["resting_with_occ_thousand"].tolist()
    high = annual["high_ai_pressure_share_pct"].tolist()
    # The x-axis covers every year present in either table.
    years = sorted(set(pop_years) | set(risk_years))
    first_year = years[0]
    year_span = (years[-1] - first_year) or 1  # a single year would otherwise divide by zero

    draw_text(d, (left, 70), "Figure 1. Resting youth and occupational AI-risk portfolio, 2021-2025", 34, bold=True)
    draw_text(
        d,
        (left, 125),
        "Weighted estimates from Korea's August supplementary labor-force microdata; population in thousand persons.",
        22,
        fill="#4b5563",
    )

    def sx(year):
        return left + (year - first_year) / year_span * plot_w

    y_min, y_max = 120, 500

    def sy_pop(v):
        return top + plot_h - (v - y_min) / (y_max - y_min) * plot_h

    y2_min, y2_max = 0, 45

    def sy_pct(v):
        return top + plot_h - (v - y2_min) / (y2_max - y2_min) * plot_h

    pop_ticks = [150, 250, 350, 450]
    # Gridlines sit on the labelled left-axis ticks so each can be read against a value.
    for tick in pop_ticks:
        y = sy_pop(tick)
        d.line((left, y, left + plot_w, y), fill="#e5e7eb", width=2)
    d.rectangle((left, top, left + plot_w, top + plot_h), outline="#9ca3af", width=2)

    for tick in pop_ticks:
        y = sy_pop(tick)
        d.line((left - 8, y, left, y), fill="#6b7280", width=2)
        draw_text(d, (left - 15, y), str(tick), 18, fill="#374151", anchor="rm")

    for tick in [0, 15, 30, 45]:
        y = sy_pct(tick)
        d.line((left + plot_w, y, left + plot_w + 8, y), fill="#6b7280", width=2)
        draw_text(d, (left + plot_w + 18, y), f"{tick}%", 18, fill="#374151", anchor="lm")

    # Axis titles sit just above the frame, flush with the axis they describe:
    # centred beside the frame they would overlap the tick labels and run past
    # the image edges.
    draw_text(d, (left, top - 16), "Thousand persons", 18, fill="#374151", anchor="ls")
    draw_text(d, (left + plot_w, top - 16), "High-risk share", 18, fill="#374151", anchor="rs")

    series = [
        (pop_years, rest, sy_pop, "#2563eb", "Resting youth"),
        (pop_years, occ_resp, sy_pop, "#059669", "Desired occupation respondents"),
        (risk_years, high, sy_pct, "#dc2626", "High AI-pressure share"),
    ]
    for series_years, values, sy, color, _ in series:
        # Sorting by x draws each polyline left to right whatever the row order.
        pts = sorted((sx(yr), sy(v)) for yr, v in zip(series_years, values))
        d.line(pts, fill=color, width=5)
        for x, y in pts:
            d.ellipse((x - 8, y - 8, x + 8, y + 8), fill=color)

    for year in years:
        x = sx(year)
        d.line((x, top + plot_h, x, top + plot_h + 8), fill="#6b7280", width=2)
        draw_text(d, (x, top + plot_h + 28), str(year), 20, fill="#374151", anchor="mt")

    legend_y = top + plot_h + 85
    legend_x = left
    for *_, color, label in series:
        d.rectangle((legend_x, legend_y, legend_x + 28, legend_y + 18), fill=color)
        draw_text(d, (legend_x + 40, legend_y + 9), label, 21, anchor="lm")
        legend_x += 390

    path = FIG / "figure1_trends.png"
    img.save(path, quality=95)
    return path


def draw_risk_matrix() -> Path:
    """Render Figure 2: occupational groups by AI substitution pressure and demand outlook.

    Each row of ``occ`` becomes a bubble at (``ai_substitution_score``,
    ``long_term_outlook_score``) on 0-100 axes, coloured by ``risk_band`` and
    labelled with its English name and ``share_pct``. The red vertical line marks
    the score-70 high-pressure cut-off and the orange horizontal line the score-50
    weak-outlook cut-off.

    Returns:
        Path of the written PNG (``FIG / "figure2_risk_matrix.png"``).
    """
    width, height = 1900, 1200
    img = Image.new("RGB", (width, height), "white")
    d = ImageDraw.Draw(img)
    left, top, right, bottom = 240, 230, 1680, 1010
    plot_w, plot_h = right - left, bottom - top

    draw_text(d, (left, 70), "Figure 2. Occupational aspiration portfolio by AI substitution pressure and demand outlook, 2025", 33, bold=True)
    draw_text(d, (left, 125), "Bubble area is proportional to the number of resting youth desiring each occupational group.", 22, fill="#4b5563")

    def sx(score):
        return left + score / 100 * plot_w

    def sy(score):
        return bottom - score / 100 * plot_h

    for tick in range(0, 101, 20):
        x = sx(tick)
        y = sy(tick)
        d.line((x, top, x, bottom), fill="#e5e7eb", width=2)
        d.line((left, y, right, y), fill="#e5e7eb", width=2)
        draw_text(d, (x, bottom + 25), str(tick), 18, fill="#374151", anchor="mt")
        draw_text(d, (left - 20, y), str(tick), 18, fill="#374151", anchor="rm")

    d.rectangle((left, top, right, bottom), outline="#9ca3af", width=2)
    d.line((sx(70), top, sx(70), bottom), fill="#ef4444", width=4)
    d.line((left, sy(50), right, sy(50)), fill="#f97316", width=4)

    draw_text(d, ((left + right) / 2, bottom + 75), "AI substitution pressure score", 24, anchor="mm")
    draw_text(d, (left, top - 35), "Demand outlook score", 22, anchor="lm")
    draw_text(d, (sx(72), top + 30), "High AI pressure", 20, fill="#b91c1c")
    draw_text(d, (left + 30, sy(47)), "Weak long-term outlook", 20, fill="#c2410c")

    max_w = occ["weighted_thousand"].max()
    max_r = 67  # px radius of the largest group's bubble
    colors = {
        "높음": "#ef4444",
        "중간-높음": "#f59e0b",
        "중간": "#64748b",
        "중간-낮음": "#10b981",
        "낮음": "#22c55e",
    }
    label_offsets = {
        "전문가 및 관련 종사자": (22, -38),
        "사무 종사자": (24, 0),
        "서비스 종사자": (-24, 36),
        "장치·기계 조작 및 조립 종사자": (-24, 44),
        "기능원 및 관련 기능 종사자": (22, 28),
        "판매 종사자": (-24, 22),
        "단순노무 종사자": (-24, 30),
        "관리자": (22, -28),
    }

    # Largest groups are painted first so smaller bubbles stay visible on top.
    ordered = occ.sort_values("weighted_thousand", ascending=False, kind="stable")
    placed = []
    for _, row in ordered.iterrows():
        x = sx(row["ai_substitution_score"])
        y = sy(row["long_term_outlook_score"])
        # Radius proportional to sqrt(weight) makes bubble AREA proportional to
        # weight, as the subtitle states; a constant radius offset would not.
        r = max_r * math.sqrt(row["weighted_thousand"] / max_w) if max_w > 0 else 0.0
        color = colors.get(row["risk_band"], "#6b7280")
        d.ellipse((x - r, y - r, x + r, y + r), fill=color, outline="#111827", width=3)
        placed.append((x, y, row))

    # Labels go on after every bubble so no bubble can paint over a label.
    for x, y, row in placed:
        dx, dy = label_offsets.get(row["occ_label"], (20, 0))
        anchor = "lm" if dx >= 0 else "rm"
        label = OCC_EN.get(row["occ_label"], row["occ_label"])
        draw_text(d, (x + dx, y + dy), f"{label} ({row['share_pct']:.1f}%)", 20, fill="#111827", anchor=anchor)

    path = FIG / "figure2_risk_matrix.png"
    img.save(path, quality=95)
    return path


def compute_segments() -> dict:
    """Build the segment and sensitivity tables reported alongside the figures.

    Reads the module-level ``occ`` (per-occupation scores and weights), ``sex`` and
    ``edu`` (weights per segment and desired-occupation code) tables.

    Returns:
        dict of DataFrames:
        - ``sex_summary`` / ``edu_summary``: weighted exposure metrics per segment,
          sorted by weight (largest first);
        - ``threshold_sensitivity``: exposed weight and share for AI-score cut-offs 50-80;
        - ``scenario_sensitivity``: average AI score and task-exposure equivalent under
          low / baseline / high adoption scenarios.

    Score-based averages and shares use as denominator only the weight whose score
    is known, so rows whose occupation code has no score (e.g. unmatched in the left
    merge) are not silently treated as a score of 0.
    """
    high_cut, broad_cut, weak_cut = 70, 55, 50
    code_col = "취업희망직종코드"
    occ_scores = occ[[code_col, "ai_substitution_score", "long_term_outlook_score"]]

    def summarize_segment(df: pd.DataFrame, group_col: str) -> pd.DataFrame:
        # Attach occupation scores to each segment row, then aggregate per segment.
        merged = df.merge(occ_scores, on=code_col, how="left")
        columns = [
            group_col,
            "weighted_thousand",
            "avg_ai_substitution_score",
            "high_ai_pressure_share_pct",
            "weak_outlook_share_pct",
            "broad_ai_exposure_share_pct",
        ]
        out = []
        for key, g in merged.groupby(group_col):
            w = g["weighted_thousand"]
            ai = g["ai_substitution_score"]
            outlook = g["long_term_outlook_score"]
            ai_total = w[ai.notna()].sum()
            outlook_total = w[outlook.notna()].sum()
            out.append(
                {
                    group_col: key,
                    "weighted_thousand": w.sum(),
                    "avg_ai_substitution_score": (w * ai).sum() / ai_total,
                    "high_ai_pressure_share_pct": w[ai.ge(high_cut)].sum() / ai_total * 100,
                    "weak_outlook_share_pct": w[outlook.lt(weak_cut)].sum() / outlook_total * 100,
                    "broad_ai_exposure_share_pct": w[ai.ge(broad_cut)].sum() / ai_total * 100,
                }
            )
        # Explicit columns keep an empty input from failing in sort_values.
        return pd.DataFrame(out, columns=columns).sort_values("weighted_thousand", ascending=False)

    sex_summary = summarize_segment(sex, "sex_label")
    edu_summary = summarize_segment(edu, "edu_label")

    weights = occ["weighted_thousand"]
    ai_scores = occ["ai_substitution_score"]
    scored_total = weights[ai_scores.notna()].sum()

    thresholds = []
    for threshold in [50, 55, 60, 65, 70, 75, 80]:
        count = weights[ai_scores.ge(threshold)].sum()
        thresholds.append({"threshold": threshold, "exposed_thousand": count, "exposed_share_pct": count / scored_total * 100})

    scenario = []
    # Baseline scores are interpreted as a medium-adoption scenario. The low/high scenarios
    # shrink or stretch the distance from 50 to express adoption uncertainty without changing
    # the ordinal structure of occupational exposure.
    for name, factor in [("low adoption", 0.8), ("baseline", 1.0), ("high adoption", 1.2)]:
        # Series.clip keeps NaN as NaN so unscored rows drop out of the sums;
        # Python's min(100, nan) would instead return 100.
        scores = (50 + (ai_scores - 50) * factor).clip(lower=0, upper=100)
        weighted = (weights * scores).sum()
        scenario.append(
            {
                "scenario": name,
                "avg_ai_substitution_score": weighted / scored_total,
                # Respondent weight times score fraction, i.e. sum(w * s) / 100.
                "task_exposure_equivalent_thousand": weighted / 100,
            }
        )

    return {
        "sex_summary": sex_summary,
        "edu_summary": edu_summary,
        "threshold_sensitivity": pd.DataFrame(thresholds),
        "scenario_sensitivity": pd.DataFrame(scenario),
    }


def draw_segment_chart(segment_tables: dict) -> Path:
    """Render Figure 3: high vs broad AI-exposure share by education level.

    Args:
        segment_tables: mapping as returned by ``compute_segments()``; only its
            ``"edu_summary"`` DataFrame is used.

    Returns:
        Path of the written PNG (``FIG / "figure3_education_exposure.png"``).
    """
    df = segment_tables["edu_summary"].copy()
    order = ["고졸", "전문대졸", "대졸", "대학원 이상", "중졸"]
    df["order"] = df["edu_label"].map({v: i for i, v in enumerate(order)}).fillna(99)
    # Stable sort keeps unlisted categories (all ranked 99) in their incoming order.
    df = df.sort_values("order", kind="stable")
    width, height = 1600, 900
    img = Image.new("RGB", (width, height), "white")
    d = ImageDraw.Draw(img)
    left, top, right, bottom = 180, 190, 1450, 720
    span = right - left
    draw_text(d, (left, 70), "Figure 3. Exposure profile by education among resting youth with desired occupations, 2025", 31, bold=True)
    draw_text(d, (left, 120), "High-risk is defined as AI substitution pressure score >= 70; broad exposure uses score >= 55.", 21, fill="#4b5563")
    d.rectangle((left, top, right, bottom), outline="#9ca3af", width=2)
    for tick in range(0, 101, 20):
        x = left + tick / 100 * span
        d.line((x, top, x, bottom), fill="#e5e7eb", width=2)
        draw_text(d, (x, bottom + 22), f"{tick}%", 18, fill="#374151", anchor="mt")

    # Rows keep a 110 px pitch (55 px bar + 55 px gap) while they fit; with more
    # categories the pitch shrinks so the last bar stays inside the frame
    # (65 px padding above the first bar, 25 px below the last).
    pitch = min(110.0, (bottom - top - 90) / max(len(df) - 0.5, 1.0))
    bar_h = pitch / 2
    y = top + 65
    for _, row in df.iterrows():
        label = EDU_EN.get(row["edu_label"], row["edu_label"])
        high = row["high_ai_pressure_share_pct"]
        broad = row["broad_ai_exposure_share_pct"]
        x_broad = left + broad / 100 * span
        x_high = left + high / 100 * span
        mid = y + bar_h / 2
        draw_text(d, (left - 25, mid), label, 22, anchor="rm")
        d.rectangle((left, y, x_broad, y + bar_h), fill="#bfdbfe")
        d.rectangle((left, y, x_high, y + bar_h), fill="#ef4444")
        draw_text(d, (x_broad + 12, mid), f"{broad:.1f}% broad", 19, fill="#1f2937", anchor="lm")
        high_text = f"{high:.1f}%"
        if x_high - left >= 80:
            # Inside the red bar the value cannot collide with the "broad" label
            # when the two shares are close.
            draw_text(d, (x_high - 8, mid), high_text, 18, fill="white", anchor="rm")
        else:
            draw_text(d, (x_high + 8, mid), high_text, 18, fill="#111827", anchor="lm")
        y += pitch
    d.rectangle((left, bottom + 70, left + 28, bottom + 92), fill="#ef4444")
    draw_text(d, (left + 40, bottom + 81), "High AI-pressure share", 20, anchor="lm")
    d.rectangle((left + 360, bottom + 70, left + 388, bottom + 92), fill="#bfdbfe")
    draw_text(d, (left + 400, bottom + 81), "Broad AI-exposure share", 20, anchor="lm")
    path = FIG / "figure3_education_exposure.png"
    img.save(path, quality=95)
    return path


# Render the figures in order; figure 3 needs the segment tables first.
figures = {}
figures["figure1"] = str(draw_line_chart())
figures["figure2"] = str(draw_risk_matrix())
segments = compute_segments()
figures["figure3"] = str(draw_segment_chart(segments))

# Each segment table is written twice: CSV with a UTF-8 BOM, and JSON records.
for name, table in segments.items():
    table.to_csv(OUT / f"{name}.csv", index=False, encoding="utf-8-sig")
    table.to_json(OUT / f"{name}.json", orient="records", force_ascii=False, indent=2)

# Headline numbers: 2025 respondents with a desired occupation, scaled by each share/score.
occ_base = latest_risk["weighted_occ_thousand"]
summary = dict(
    latest=latest.to_dict(),
    latest_risk=latest_risk.to_dict(),
    high_pressure_thousand=occ_base * latest_risk["high_ai_pressure_share_pct"] / 100,
    weak_outlook_thousand=occ_base * latest_risk["weak_outlook_share_pct"] / 100,
    task_exposure_equivalent_thousand=occ_base * latest_risk["avg_ai_substitution_score"] / 100,
    figures=figures,
)
summary_text = json.dumps(summary, ensure_ascii=False, indent=2)
(OUT / "tfsc_analysis_summary.json").write_text(summary_text, encoding="utf-8")
print(summary_text)
