#!/usr/bin/env python3
"""Heat GPU Cloud — Provider Client v0.2.0

ご自身の GPU を Heat GPU Cloud に接続するクライアント。

機能:
  - リアルタイム GPU メトリクス送信 (温度/使用率/VRAM/電力/クロック)
  - ネットワーク性能ベンチ (ping/帯域)
  - GPU コンピュートベンチ (matmul TFLOPS — torch があれば)
  - CPU/RAM ベンチ
  - Docker + NVIDIA Container Toolkit 自動検知
  - ジョブ受付準備チェック → executor_ready 判定

使い方:
    python3 heat_provider.py --token YOUR_CLIENT_TOKEN
    HEAT_PROVIDER_TOKEN=xxx python3 heat_provider.py

オプション:
    --token TOKEN          Client Token
    --interval N           Heartbeat 間隔 (秒・既定 30)
    --api URL              API endpoint (既定 https://api.heatgpu.com)
    --max-util N           GPU 使用率上限 (%・既定 80)
    --bench-on-start       起動時にフルベンチマーク実行 (既定 on)
    --no-bench             ベンチ実行をスキップ
    --once                 1回だけ heartbeat 送信して終了
"""
import argparse
import json
import os
import platform
import shutil
import signal
import socket
import subprocess
import sys
import time
import urllib.error
import urllib.request

# Client version: printed in the banner and sent to the API (request bodies and User-Agent header).
VERSION = "0.2.0"
# Default API base URL; overridable with --api or the HEAT_PROVIDER_API environment variable.
DEFAULT_API = "https://api.heatgpu.com"
# Default number of seconds between heartbeats (--interval).
DEFAULT_INTERVAL = 30
# Default GPU utilisation threshold in percent (--max-util); above it the heartbeat reports is_busy.
DEFAULT_MAX_UTIL = 80
# Asset downloaded by bench_network() to estimate download throughput.
# NOTE(review): the original note claimed "~2MB"; the real size cannot be confirmed from this file.
BENCH_PAYLOAD_URL = "https://heatgpu.com/assets/ads/provider/ogp_1200x630.png"

# ──────────────────────────────────────────────────────────────────────────────
# Banner
def banner():
    """Print the client banner: title with version and the provider URL, framed by two rules."""
    rule = "─" * 64
    rows = (
        rule,
        f"  Heat GPU Cloud — Provider Client  v{VERSION}",
        "  https://heatgpu.com/provider/",
        rule,
    )
    for row in rows:
        print(row)


# ──────────────────────────────────────────────────────────────────────────────
# Hardware detection
def _smi_int(text):
    """Convert one nvidia-smi CSV cell to ``int``.

    Cells that are not numeric (for example "[N/A]", "[Not Supported]" or an
    empty string) are reported as 0 so a missing sensor never breaks parsing.
    """
    try:
        return int(float(text))
    except (ValueError, OverflowError):
        return 0


def detect_gpu_detailed():
    """Query ``nvidia-smi`` for per-GPU details and live metrics.

    Returns:
        A list with one dict per GPU (name, uuid, VRAM, utilisation,
        temperature, power, clocks, pstate, driver, fan), or ``None`` when
        nvidia-smi is missing, times out, exits non-zero, or yields no
        parsable row.
    """
    fields = [
        "name", "uuid",
        "memory.total", "memory.used", "memory.free",
        "utilization.gpu", "utilization.memory",
        "temperature.gpu", "power.draw", "power.limit",
        "clocks.gr", "clocks.mem", "clocks.sm",
        "pstate", "driver_version", "fan.speed",
    ]
    try:
        r = subprocess.run(
            ["nvidia-smi", f"--query-gpu={','.join(fields)}",
             "--format=csv,noheader,nounits"],
            capture_output=True, text=True, timeout=8,
        )
        if r.returncode != 0:
            return None
        gpus = []
        for line in r.stdout.strip().split("\n"):
            parts = [p.strip() for p in line.split(",")]
            # Skip blank or truncated rows instead of indexing past the end.
            if len(parts) < len(fields):
                continue
            gpus.append({
                "name": parts[0],
                "uuid": parts[1],
                "vram_total_mb": _smi_int(parts[2]),
                "vram_used_mb": _smi_int(parts[3]),
                "vram_free_mb": _smi_int(parts[4]),
                "gpu_util_pct": _smi_int(parts[5]),
                "mem_util_pct": _smi_int(parts[6]),
                "temp_c": _smi_int(parts[7]),
                "power_w": _smi_int(parts[8]),
                "power_limit_w": _smi_int(parts[9]),
                "clock_graphics_mhz": _smi_int(parts[10]),
                "clock_memory_mhz": _smi_int(parts[11]),
                "clock_sm_mhz": _smi_int(parts[12]),
                "pstate": parts[13],
                "driver": parts[14],
                "fan_pct": _smi_int(parts[15]),
            })
        return gpus or None
    except (FileNotFoundError, subprocess.TimeoutExpired):
        # No driver tooling installed, or the tool hung: treated as "no GPU".
        return None
    except Exception as e:
        print(f"  ! GPU 検出エラー: {e}", file=sys.stderr)
        return None


def detect_cpu():
    """Describe the host CPU using only the standard library.

    Returns a dict with ``model``, ``cores_logical`` and ``arch``. The model
    string is refined from /proc/cpuinfo on Linux or ``sysctl`` on macOS;
    any failure there keeps the ``platform.processor()`` fallback.
    """
    info = {}
    info["model"] = platform.processor() or "Unknown"
    info["cores_logical"] = os.cpu_count() or 1
    info["arch"] = platform.machine()
    try:
        if platform.system() == "Linux":
            with open("/proc/cpuinfo") as cpuinfo:
                content = cpuinfo.read()
            # Only the first "model name" entry is used.
            first_match = next(
                (row for row in content.split("\n") if row.startswith("model name")),
                None,
            )
            if first_match is not None:
                info["model"] = first_match.split(":", 1)[1].strip()
        elif platform.system() == "Darwin":
            probe = subprocess.run(
                ["sysctl", "-n", "machdep.cpu.brand_string"],
                capture_output=True, text=True, timeout=3,
            )
            if probe.returncode == 0:
                info["model"] = probe.stdout.strip()
    except Exception:
        # Best-effort refinement; the generic values gathered above remain.
        pass
    return info


def detect_ram_gb():
    """Return the physical RAM size in GB (one decimal), or 0 when it cannot be determined.

    Linux reads ``MemTotal`` from /proc/meminfo; macOS asks ``sysctl hw.memsize``.
    Other platforms, and any error, yield 0.
    """
    try:
        if platform.system() == "Linux":
            with open("/proc/meminfo") as meminfo:
                for entry in meminfo:
                    if not entry.startswith("MemTotal:"):
                        continue
                    # The second column is divided by 1024 twice to obtain GB.
                    total = int(entry.split()[1])
                    return round(total / 1024 / 1024, 1)
        elif platform.system() == "Darwin":
            probe = subprocess.run(
                ["sysctl", "-n", "hw.memsize"],
                capture_output=True, text=True, timeout=3,
            )
            if probe.returncode == 0:
                total = int(probe.stdout.strip())
                return round(total / 1024**3, 1)
    except Exception:
        # Best-effort probe: fall through to the "unknown" value.
        pass
    return 0


def detect_disk_free_gb(path="/"):
    """Return the free space of the filesystem containing *path* in GB (one decimal).

    Returns 0 when the path cannot be inspected.
    """
    try:
        free_bytes = shutil.disk_usage(path).free
    except Exception:
        return 0
    return round(free_bytes / 1024**3, 1)


def detect_docker():
    """Report whether Docker and the NVIDIA container runtime are available.

    Returns a dict with ``docker``, ``docker_version``, ``nvidia_runtime`` and
    ``ready_for_jobs``. ``ready_for_jobs`` is set only when ``docker info``
    succeeds and its runtime list mentions "nvidia".
    """
    status = {
        "docker": False,
        "docker_version": None,
        "nvidia_runtime": False,
        "ready_for_jobs": False,
    }
    if not shutil.which("docker"):
        return status

    try:
        version_probe = subprocess.run(
            ["docker", "--version"], capture_output=True, text=True, timeout=5
        )
        if version_probe.returncode == 0:
            status["docker"] = True
            # "Docker version 24.0.7, build abc" -> "24.0.7"
            trimmed = version_probe.stdout.strip().replace("Docker version ", "")
            status["docker_version"] = trimmed.split(",")[0]
    except Exception:
        pass

    if not status["docker"]:
        return status

    try:
        runtime_probe = subprocess.run(
            ["docker", "info", "--format", "{{json .Runtimes}}"],
            capture_output=True, text=True, timeout=5,
        )
        if runtime_probe.returncode == 0 and "nvidia" in runtime_probe.stdout:
            status["nvidia_runtime"] = True
            status["ready_for_jobs"] = True
    except Exception:
        pass
    return status


def detect_system():
    """Collect host and runtime identification: OS, hostname, Python and client version, architecture."""
    info = {}
    info["os"] = platform.system()
    info["os_release"] = platform.release()
    # The OS version string is capped at 100 characters.
    info["os_version"] = platform.version()[:100]
    info["hostname"] = socket.gethostname()
    info["python"] = platform.python_version()
    info["client_version"] = VERSION
    info["arch"] = platform.machine()
    return info


# ──────────────────────────────────────────────────────────────────────────────
# Benchmarks
def bench_network(api_url):
    """Measure rough network performance against the API and the payload host.

    Three independent, best-effort probes are run; a probe that fails simply
    leaves its entry as ``None``:

    * ``api_ping_ms``   - mean round-trip of up to three HEAD requests.
    * ``download_mbps`` - throughput of fetching ``BENCH_PAYLOAD_URL``
      (also records ``downloaded_bytes``).
    * ``upload_mbps``   - throughput of POSTing a 256 KB dummy body; this is
      an approximation, not an exact bandwidth figure.

    Args:
        api_url: API base URL without a trailing slash.

    Returns:
        Dict with the three keys above (plus ``downloaded_bytes`` when the
        download succeeded).
    """
    out = {"api_ping_ms": None, "download_mbps": None, "upload_mbps": None}

    # Ping: average of up to three HEAD requests.
    samples = []
    for _ in range(3):
        try:
            t0 = time.perf_counter()
            req = urllib.request.Request(api_url + "/api/family/plans", method="HEAD")
            # Context manager closes the connection instead of leaking it.
            with urllib.request.urlopen(req, timeout=5) as resp:
                resp.read()
                elapsed = time.perf_counter() - t0
            samples.append(elapsed * 1000)
        except Exception:
            # Best-effort: a failed attempt contributes no sample.
            pass
    if samples:
        out["api_ping_ms"] = round(sum(samples) / len(samples), 1)

    # Download throughput using a larger payload.
    try:
        t0 = time.perf_counter()
        with urllib.request.urlopen(BENCH_PAYLOAD_URL, timeout=15) as r:
            data = r.read()
        dt = time.perf_counter() - t0
        if dt > 0:
            out["download_mbps"] = round(len(data) * 8 / 1024 / 1024 / dt, 1)
            out["downloaded_bytes"] = len(data)
    except Exception:
        pass

    # Upload throughput, approximated by POSTing a dummy body.
    try:
        payload = b"x" * 256 * 1024  # 256 KB
        t0 = time.perf_counter()
        req = urllib.request.Request(
            api_url + "/api/family/plans",
            data=payload,
            headers={"Content-Type": "application/octet-stream"},
            method="POST",
        )
        try:
            with urllib.request.urlopen(req, timeout=10) as resp:
                resp.read()
        except urllib.error.HTTPError as err:
            # An HTTP error status (e.g. 405) is fine: the body was still sent
            # and only the elapsed time matters. Close the error response so
            # its socket is released.
            err.close()
        dt = time.perf_counter() - t0
        if dt > 0:
            out["upload_mbps"] = round(len(payload) * 8 / 1024 / 1024 / dt, 1)
    except Exception:
        pass

    return out


def bench_cpu(duration=2.0):
    """CPU benchmark: count SHA-256 digests of a fixed buffer computed in *duration* seconds.

    Args:
        duration: Nominal measuring time in seconds (default 2.0).

    Returns:
        ``{"sha256_per_sec": int}`` - digests completed divided by *duration*.

    Raises:
        ValueError: If *duration* is not positive.
    """
    import hashlib

    if duration <= 0:
        raise ValueError("duration must be positive")

    data = b"heat-gpu-cloud-cpu-benchmark" * 100
    count = 0
    end = time.perf_counter() + duration
    while time.perf_counter() < end:
        hashlib.sha256(data).digest()
        count += 1
    return {"sha256_per_sec": int(count / duration)}


def _time_matmul(torch, lhs, rhs, iters, warmup=3):
    """Return the mean seconds per ``lhs @ rhs`` over *iters* runs, after *warmup* untimed runs."""
    for _ in range(warmup):
        lhs @ rhs
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        lhs @ rhs
    # Wait for the queued kernels so the elapsed time covers the actual work.
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


def bench_gpu():
    """GPU compute benchmark: matmul throughput via PyTorch when it is available.

    Returns:
        Dict with ``method``, ``tflops_fp32``, ``tflops_fp16`` and
        ``bandwidth_gb_s``. ``method`` is ``"torch_matmul"`` on success,
        ``"no_torch"`` when PyTorch is not installed, ``"torch_no_cuda"``
        when CUDA is unavailable, or ``"error:<ExceptionName>"``.
        ``bandwidth_gb_s`` is never measured here and stays ``None``.
    """
    out = {"method": "none", "tflops_fp32": None, "tflops_fp16": None, "bandwidth_gb_s": None}
    size = 2048
    iters = 20
    flop_per_matmul = 2 * size**3  # FLOPs = 2 * N^3 for an N x N matmul

    try:
        import torch
        if not torch.cuda.is_available():
            out["method"] = "torch_no_cuda"
            return out

        dev = torch.device("cuda:0")
        a = torch.randn(size, size, device=dev, dtype=torch.float32)
        b = torch.randn(size, size, device=dev, dtype=torch.float32)
        dt = _time_matmul(torch, a, b, iters)
        out["tflops_fp32"] = round(flop_per_matmul / dt / 1e12, 2)
        out["method"] = "torch_matmul"

        # FP16 is optional: a failure here keeps the FP32 result.
        a_h = b_h = None
        try:
            a_h, b_h = a.half(), b.half()
            dt = _time_matmul(torch, a_h, b_h, iters)
            out["tflops_fp16"] = round(flop_per_matmul / dt / 1e12, 2)
        except Exception:
            pass

        # Drop every benchmark tensor before asking the allocator to release memory.
        del a, b, a_h, b_h
        torch.cuda.empty_cache()
    except ImportError:
        out["method"] = "no_torch"
    except Exception as e:
        out["method"] = f"error:{type(e).__name__}"

    return out


# ──────────────────────────────────────────────────────────────────────────────
# HTTP
def _exchange_json(url, headers, timeout, data=None, method=None):
    """Send one request and return ``(status, body)`` without ever raising.

    * Successful response: ``(status, parsed JSON)``; an empty or non-JSON
      body yields ``{}`` so the status code is not lost.
    * HTTP error status: ``(code, parsed JSON error body or {})``.
    * Anything else (bad URL, connection failure, timeout):
      ``(None, {"error": message})``.
    """
    try:
        req = urllib.request.Request(url, data=data, headers=headers, method=method)
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            status = resp.status
            raw = resp.read()
    except urllib.error.HTTPError as err:
        body = {}
        try:
            # The error object doubles as the response; close it after reading.
            with err:
                body = json.loads(err.read())
        except Exception:
            pass
        return err.code, body
    except Exception as err:
        return None, {"error": str(err)}

    try:
        return status, json.loads(raw)
    except ValueError:
        return status, {}


def http_post(url, token, payload, timeout=15):
    """POST *payload* as JSON with bearer authentication.

    Returns:
        ``(status, body)``; ``status`` is ``None`` when no HTTP response was
        obtained, in which case ``body`` is ``{"error": message}``.
    """
    data = json.dumps(payload).encode("utf-8")
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}",
        "User-Agent": f"HeatProviderClient/{VERSION}",
    }
    return _exchange_json(url, headers, timeout, data=data, method="POST")


def http_get(url, token, timeout=10):
    """GET *url* with bearer authentication.

    Returns:
        ``(status, body)``; ``status`` is ``None`` when no HTTP response was
        obtained, in which case ``body`` is ``{"error": message}``.
    """
    headers = {
        "Authorization": f"Bearer {token}",
        "User-Agent": f"HeatProviderClient/{VERSION}",
    }
    return _exchange_json(url, headers, timeout)


# Cooperative shutdown flag: main_loop() keeps iterating while this is True.
_running = True
def _stop(_sig=None, _frame=None):
    """Signal handler: request shutdown by clearing ``_running`` and print a notice.

    Both parameters are ignored; they exist so the function matches the
    two-argument handler signature used by ``signal.signal``.
    """
    global _running
    _running = False
    print("\n  停止しています...")
# Registered at import time for both SIGINT (Ctrl+C) and SIGTERM.
signal.signal(signal.SIGINT, _stop)
signal.signal(signal.SIGTERM, _stop)


# ──────────────────────────────────────────────────────────────────────────────
# Main loop
def _mask_token(token):
    """Abbreviate *token* for display; tokens too short to abbreviate are fully masked."""
    if len(token) > 12:
        return f"{token[:8]}…{token[-4:]}"
    return "*" * len(token)


def _sleep_while_running(seconds):
    """Sleep up to *seconds* in 1-second steps, returning early once shutdown is requested."""
    for _ in range(seconds):
        if not _running:
            break
        time.sleep(1)


def _run_startup_benchmark(args, sys_info, docker, gpus):
    """Run the network/CPU/GPU benchmarks, print the results and upload them.

    Exits the process with status 2 on HTTP 401 and 3 on HTTP 403; any other
    upload failure is printed and otherwise ignored.
    """
    bench = {}
    print("  ★ 起動時ベンチマーク実行中 (10〜30秒)...")
    print("    [1/3] Network...", end=" ", flush=True)
    bench["network"] = bench_network(args.api)
    n = bench["network"]
    print(f"ping {n.get('api_ping_ms')}ms  ↓{n.get('download_mbps')}Mbps  ↑{n.get('upload_mbps')}Mbps")
    print("    [2/3] CPU...    ", end=" ", flush=True)
    bench["cpu"] = bench_cpu()
    print(f"{bench['cpu']['sha256_per_sec']:,} sha256/s")
    print("    [3/3] GPU...    ", end=" ", flush=True)
    bench["gpu"] = bench_gpu()
    if bench["gpu"]["tflops_fp32"]:
        # The key is always present, so test the value to get the "—" fallback.
        fp16 = bench["gpu"].get("tflops_fp16")
        fp16_text = "—" if fp16 is None else fp16
        print(f"{bench['gpu']['tflops_fp32']} TFLOPS (FP32) / {fp16_text} TFLOPS (FP16)")
    else:
        print(f"method={bench['gpu']['method']}  (PyTorch インストール推奨で詳細ベンチ可能)")

    # Upload the benchmark results.
    status, body = http_post(
        f"{args.api}/api/provider/client/benchmark",
        args.token,
        {
            "ts": int(time.time()),
            "system": sys_info,
            "cpu": detect_cpu(),
            "ram_gb": detect_ram_gb(),
            "disk_free_gb": detect_disk_free_gb(),
            "docker": docker,
            "bench": bench,
            "gpus_detail": gpus,
        },
        timeout=20,
    )
    if status == 200:
        print(f"  ★ ベンチマーク送信完了 → executor_ready={body.get('executor_ready')}")
    elif status == 401:
        print("  ! 認証失敗。Client Token を確認してください。")
        sys.exit(2)
    elif status == 403:
        print("  ! GPU 未登録。Provider Portal で GPU を追加してください。")
        sys.exit(3)
    else:
        print(f"  ! ベンチ送信失敗 status={status}: {str(body)[:120]}")


def _poll_job(args, ts):
    """Poll for a job and, if one is offered, report it back as skipped.

    Job execution is not implemented: a received job is acknowledged with
    status "skipped". A job without an id is printed but not acknowledged.
    """
    jstatus, jbody = http_get(f"{args.api}/api/provider/client/poll-job", args.token, timeout=10)
    job = jbody.get("job") if isinstance(jbody, dict) else None
    if jstatus != 200 or not job or not isinstance(job, dict):
        return
    job_id = job.get("id")
    print(f"  [{ts}] JOB 受信: {job_id}")
    if job_id is None:
        return
    http_post(f"{args.api}/api/provider/client/submit-result", args.token,
              {"job_id": job_id, "status": "skipped",
               "reason": "executor_beta_not_started", "client_version": VERSION},
              timeout=20)


def main_loop(args):
    """Print the host summary, optionally benchmark, then send heartbeats until stopped.

    Args:
        args: Parsed CLI namespace; reads ``api``, ``token``, ``interval``,
            ``max_util``, ``no_bench`` and ``once``. An ``interval`` below 1
            is treated as 1 so the loop never spins without pausing.

    Exits the process with status 1 when no NVIDIA GPU is detected, 2 on an
    authentication failure (HTTP 401) and 3 when the benchmark upload is
    rejected with HTTP 403. On a normal stop a final "going_offline"
    heartbeat is sent.
    """
    interval = max(1, args.interval)

    sys_info = detect_system()
    print(f"  System:  {sys_info['os']} {sys_info['os_release']} ({sys_info['arch']})")
    print(f"  Host:    {sys_info['hostname']}")
    print(f"  CPU:     {detect_cpu()['model']}  ({os.cpu_count()} cores)")
    print(f"  RAM:     {detect_ram_gb()} GB")
    print(f"  Disk:    {detect_disk_free_gb()} GB free")

    gpus = detect_gpu_detailed()
    if not gpus:
        print()
        print("  ! NVIDIA GPU + ドライバが見つかりません。")
        print("    Heat Provider は NVIDIA GPU (RTX 3060 以上推奨) が必要です。")
        print("    nvidia-smi のインストールを確認してください。")
        sys.exit(1)
    print(f"  GPU:     {len(gpus)} 台検出")
    for i, g in enumerate(gpus):
        print(f"    [{i}] {g['name']}  {g['vram_total_mb']}MB  driver {g['driver']}  power {g['power_w']}/{g['power_limit_w']}W")

    docker = detect_docker()
    if docker["docker"]:
        rt = "+ NVIDIA runtime" if docker["nvidia_runtime"] else "(NVIDIA runtime 未設定)"
        print(f"  Docker:  {docker['docker_version']}  {rt}")
    else:
        print("  Docker:  ! 未インストール (ベータ後にジョブ実行で必須)")

    print(f"  API:     {args.api}")
    print(f"  Token:   {_mask_token(args.token)}")
    print(f"  Heartbeat: {interval}s 間隔")
    print("─" * 64)

    # Startup benchmark (skipped with --no-bench).
    if not args.no_bench:
        _run_startup_benchmark(args, sys_info, docker, gpus)

    print("─" * 64)
    print("  リアルタイム監視を開始しました (Ctrl+C で停止)")
    print("─" * 64)

    consecutive_fail = 0
    iteration = 0
    while _running:
        iteration += 1
        # Keep the previous reading if this detection attempt fails.
        gpus = detect_gpu_detailed() or gpus
        max_util = max(g["gpu_util_pct"] for g in gpus)
        max_temp = max(g["temp_c"] for g in gpus)
        total_power = sum(g["power_w"] for g in gpus)
        is_busy = max_util > args.max_util

        payload = {
            "ts": int(time.time()),
            "client_version": VERSION,
            "system": sys_info,
            "gpus": gpus,
            "is_busy": is_busy,
            "max_util_pct": max_util,
            "max_temp_c": max_temp,
            "total_power_w": total_power,
            "docker_ready": docker.get("ready_for_jobs", False),
        }
        status, body = http_post(
            f"{args.api}/api/provider/client/heartbeat",
            args.token, payload, timeout=10,
        )
        ts = time.strftime("%H:%M:%S")
        if status == 200:
            consecutive_fail = 0
            mark = "BUSY" if is_busy else "OK  "
            print(f"  [{ts}] {mark} util={max_util:3d}%  temp={max_temp}°C  power={total_power}W  vram={gpus[0]['vram_used_mb']}/{gpus[0]['vram_total_mb']}MB")
        elif status == 401:
            print(f"  [{ts}] ! 認証失敗。終了。")
            sys.exit(2)
        else:
            consecutive_fail += 1
            print(f"  [{ts}] ! 送信失敗 status={status} body={str(body)[:100]}")
            if consecutive_fail >= 10:
                print(f"  ! 連続失敗 {consecutive_fail} 回。30秒待機します。")
                # Back off, but stay responsive to Ctrl+C / SIGTERM.
                _sleep_while_running(30)
                consecutive_fail = 0

        # Job polling every third iteration while the GPU is not busy
        # (jobs are only acknowledged as skipped for now).
        if not is_busy and iteration % 3 == 0:
            _poll_job(args, ts)

        if args.once:
            break
        _sleep_while_running(interval)

    print("  offline 通知を送信中...")
    http_post(f"{args.api}/api/provider/client/heartbeat", args.token,
              {"ts": int(time.time()), "going_offline": True, "client_version": VERSION}, timeout=5)
    print("  終了しました。")


def main():
    """CLI entry point: parse arguments, print the banner and start the main loop.

    The token comes from ``--token`` or the HEAT_PROVIDER_TOKEN environment
    variable; the API base URL from ``--api`` or HEAT_PROVIDER_API. Exits
    with status 1 when no token is supplied.
    """
    p = argparse.ArgumentParser(description=f"Heat GPU Cloud — Provider Client v{VERSION}",
                                formatter_class=argparse.RawDescriptionHelpFormatter,
                                epilog=__doc__)
    p.add_argument("--token", default=os.environ.get("HEAT_PROVIDER_TOKEN", ""))
    p.add_argument("--interval", type=int, default=DEFAULT_INTERVAL)
    p.add_argument("--api", default=os.environ.get("HEAT_PROVIDER_API", DEFAULT_API))
    p.add_argument("--max-util", type=int, default=DEFAULT_MAX_UTIL)
    # Documented in the usage text as "on by default": accepted as a no-op so
    # the documented invocation is not rejected; --no-bench still disables it.
    p.add_argument("--bench-on-start", action="store_true",
                   help="起動時にフルベンチマーク実行 (既定 on)")
    p.add_argument("--no-bench", action="store_true", help="起動時ベンチをスキップ")
    p.add_argument("--once", action="store_true")
    p.add_argument("--version", action="version", version=f"%(prog)s {VERSION}")
    args = p.parse_args()

    banner()
    if not args.token:
        print("  ! --token を指定するか HEAT_PROVIDER_TOKEN を設定してください。")
        print("    Token: https://heatgpu.com/provider/gpus.html")
        sys.exit(1)
    main_loop(args)


if __name__ == "__main__":
    main()
