#!/usr/bin/env python3
import http.server
import json
import subprocess
import os
from pathlib import Path

# Filesystem locations, all anchored at the directory holding this script.
# The "../" segments are joined as-is (not resolved).
SCRIPT_DIR = Path(__file__).parent.absolute()
# Native prover / verifier executables, three directories up under build/.
PROVER_PATH = SCRIPT_DIR.joinpath("../../../build/webgpu_prover")
VERIFIER_PATH = SCRIPT_DIR.joinpath("../../../build/webgpu_verifier")
# Program passed as "program" to both binaries.
WASM_PATH = SCRIPT_DIR.joinpath("../cpp/build/age_verify.wasm")
# Passed as "shader-path" to both binaries.
SHADER_PATH = SCRIPT_DIR.joinpath("../../../shader")
# Proof file whose size is reported after a successful prover run.
PROOF_FILE = SCRIPT_DIR.joinpath("proof_data.gz")

class DemoHandler(http.server.SimpleHTTPRequestHandler):
    """Serve the demo's static files plus two JSON endpoints for proofs.

    GET requests are served as static files from the current working
    directory, with ``/`` mapped to ``/index.html``.  POST
    ``/generate-proof`` runs the prover binary and POST ``/verify-proof``
    runs the verifier binary; both expect a JSON object body and reply with
    a JSON object carrying a ``success`` flag.

    Error statuses: 404 for unknown POST paths, 400 for malformed requests
    (bad Content-Length, non-JSON body, missing or non-integer date fields),
    500 when the prover/verifier cannot be run or the prover fails.
    """

    # Request keys, in the order the WASM program takes its six i64 arguments.
    _BIRTH_FIELDS = ('birth_year', 'birth_month', 'birth_day')
    _CURRENT_FIELDS = ('current_year', 'current_month', 'current_day')
    # Placeholder birthdate the verifier is given in place of a real one;
    # the verify endpoint does not read a birthdate from the request.
    _DUMMY_BIRTHDATE = (1900, 1, 1)

    class _BadRequest(ValueError):
        """The client sent a malformed request; answered with HTTP 400."""

    def do_GET(self):
        """Serve a static file, mapping the site root to ``/index.html``."""
        if self.path == '/':
            self.path = '/index.html'
        return super().do_GET()

    def do_POST(self):
        """Route a POST to its JSON endpoint handler.

        The path is checked before the body is read, so unknown endpoints get
        a 404 whatever the body contains.  Body errors are answered with a 400
        JSON error instead of escaping the handler, which would close the
        connection without sending any response.
        """
        routes = {
            '/generate-proof': self.handle_generate_proof,
            '/verify-proof': self.handle_verify_proof,
        }
        handler = routes.get(self.path)
        if handler is None:
            self.send_error(404, "Endpoint not found")
            return

        try:
            data = self._read_json_body()
        except self._BadRequest as e:
            print(f"[Server] Bad request: {e}", flush=True)
            self._send_json(400, {'success': False, 'error': str(e)})
            return

        handler(data)

    def _read_json_body(self):
        """Read the request body and parse it as a UTF-8 JSON object.

        Raises _BadRequest if Content-Length is missing, non-numeric or
        negative, if the body is not valid UTF-8 JSON, or if the decoded
        value is not a JSON object (dict).
        """
        try:
            length = int(self.headers.get('Content-Length'))
        except (TypeError, ValueError):  # header absent (None) or not a number
            length = -1
        # A negative length would make rfile.read() read until EOF, which can
        # block for as long as the client keeps the connection open.
        if length < 0:
            raise self._BadRequest("missing or invalid Content-Length header")

        try:
            data = json.loads(self.rfile.read(length).decode('utf-8'))
        except ValueError as e:  # covers JSONDecodeError and UnicodeDecodeError
            raise self._BadRequest(f"request body is not valid JSON: {e}") from None

        if not isinstance(data, dict):
            raise self._BadRequest("request body must be a JSON object")
        return data

    @classmethod
    def _require_ints(cls, data, fields):
        """Return the values of ``fields`` from ``data`` as ints, in order.

        Accepts JSON integers, integral floats (e.g. ``2000.0``) and integer
        strings (e.g. ``"2000"``, as a form field might send).  Raises
        _BadRequest naming the first field that is missing, is not an
        integer (booleans included), or does not fit in a signed 64-bit int.
        """
        values = []
        for field in fields:
            if field not in data:
                raise cls._BadRequest(f"missing field '{field}'")
            raw = data[field]
            # bool is an int subclass, and int() would silently truncate 1.5,
            # so reject both explicitly.
            if isinstance(raw, bool) or (isinstance(raw, float) and not raw.is_integer()):
                raise cls._BadRequest(f"field '{field}' must be an integer")
            try:
                value = int(raw)
            except (TypeError, ValueError):
                raise cls._BadRequest(f"field '{field}' must be an integer") from None
            if not -2**63 <= value < 2**63:
                raise cls._BadRequest(f"field '{field}' is out of range for i64")
            values.append(value)
        return values

    @staticmethod
    def _build_config(args):
        """Return the JSON-serializable config passed to the prover/verifier.

        ``args`` are the six WASM program arguments in the order birth year,
        month, day, then current year, month, day.
        """
        return {
            "program": str(WASM_PATH),
            "shader-path": str(SHADER_PATH),
            "packing": 8192,
            # NOTE(review): if these indices are 0-based, the private
            # arguments are birth_month, birth_day and current_year while
            # birth_year stays public; they only match the birthdate fields
            # if 1-based.  Confirm against the prover's convention.
            "private-indices": [1, 2, 3],
            "args": [{"i64": value} for value in args],
        }

    @staticmethod
    def _run(binary, config, timeout):
        """Run ``binary`` with ``config`` serialized as its single JSON argument.

        Runs with cwd=SCRIPT_DIR, captures stdout/stderr as text, and raises
        subprocess.TimeoutExpired after ``timeout`` seconds.
        """
        return subprocess.run(
            [str(binary), json.dumps(config)],
            capture_output=True,
            text=True,
            timeout=timeout,
            cwd=SCRIPT_DIR,
        )

    def _send_json(self, status, payload):
        """Send ``payload`` as a complete JSON response with HTTP ``status``."""
        body = json.dumps(payload).encode('utf-8')
        self.send_response(status)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def handle_generate_proof(self, data):
        """Generate a proof for the posted birthdate and current date.

        ``data`` must contain the _BIRTH_FIELDS and _CURRENT_FIELDS keys.
        Replies 200 with ``proof_size`` (size of PROOF_FILE in bytes) on
        success, 400 for bad fields, and 500 if the prover fails, times out,
        or writes no proof.  The response is built first and sent once, so a
        failure while writing it cannot trigger a second response.
        """
        try:
            # Remove any proof from an earlier request up front.  Otherwise a
            # failed attempt (e.g. an under-18 birthdate) would leave the old
            # proof on disk, and a following /verify-proof could succeed on it
            # (assuming the verifier reads the PROOF_FILE the prover writes).
            try:
                PROOF_FILE.unlink()
            except FileNotFoundError:
                pass

            args = self._require_ints(data, self._BIRTH_FIELDS + self._CURRENT_FIELDS)

            if not WASM_PATH.exists():
                raise FileNotFoundError(f"WASM file not found at {WASM_PATH}")

            print("[Prover] Generating proof...", flush=True)
            result = self._run(PROVER_PATH, self._build_config(args), timeout=60)

            if result.returncode != 0:
                # Heuristic: an empty stderr is taken to mean the age
                # constraint was not satisfied, rather than a technical error.
                if not (result.stderr or '').strip():
                    raise Exception("Proof generation failed - Age constraint not satisfied (user may be under 18)")
                raise Exception(f"Prover failed: {result.stderr}")

            if not PROOF_FILE.exists():
                raise FileNotFoundError(f"Prover exited successfully but wrote no proof to {PROOF_FILE}")

            proof_size = PROOF_FILE.stat().st_size
            print(f"[Prover] Success - {proof_size} bytes", flush=True)
            status, payload = 200, {'success': True, 'proof_size': proof_size}

        except self._BadRequest as e:
            print(f"[Prover] Bad request: {e}", flush=True)
            status, payload = 400, {'success': False, 'error': str(e)}
        except Exception as e:
            print(f"[Prover] Error: {e}", flush=True)
            status, payload = 500, {'success': False, 'error': str(e)}

        self._send_json(status, payload)

    def handle_verify_proof(self, data):
        """Verify the stored proof against the posted current date.

        ``data`` must contain the _CURRENT_FIELDS keys; the birthdate
        arguments are filled with _DUMMY_BIRTHDATE.  Replies 200 with
        ``success`` (True iff the verifier exits 0) and a ``message`` once the
        verifier has run, 400 for bad fields, and 500 if the verifier cannot
        be run or times out.
        """
        try:
            current = self._require_ints(data, self._CURRENT_FIELDS)

            print("[Verifier] Verifying proof...", flush=True)
            result = self._run(
                VERIFIER_PATH,
                self._build_config([*self._DUMMY_BIRTHDATE, *current]),
                timeout=30,
            )

            # NOTE(review): every non-zero exit is reported as "under 18",
            # including a missing proof or other technical verifier failures.
            success = result.returncode == 0
            message = (
                'Age verification successful - User is 18 or older'
                if success
                else 'Age verification failed - User is under 18'
            )
            print(f"[Verifier] {message}", flush=True)
            status, payload = 200, {'success': success, 'message': message}

        except self._BadRequest as e:
            print(f"[Verifier] Bad request: {e}", flush=True)
            status, payload = 400, {'success': False, 'error': str(e)}
        except Exception as e:
            print(f"[Verifier] Error: {e}", flush=True)
            status, payload = 500, {'success': False, 'error': str(e)}

        self._send_json(status, payload)


def main(port=8000, host='localhost'):
    """Serve the demo on ``host``:``port`` until interrupted with Ctrl+C.

    Changes the working directory to SCRIPT_DIR first, because the request
    handler serves static files from the working directory.  Requests are
    handled one at a time (plain HTTPServer, not ThreadingHTTPServer).  The
    prover handler writes a single shared PROOF_FILE, so concurrent handling
    could let requests overwrite each other's proofs.
    """
    print("=" * 60, flush=True)
    print("Age Verification Demo Server", flush=True)
    print("=" * 60, flush=True)
    print(f"Starting server on http://{host}:{port}", flush=True)
    print("Press Ctrl+C to stop\n", flush=True)

    os.chdir(SCRIPT_DIR)
    # The context manager calls server_close(), releasing the listening
    # socket on exit.  shutdown() is not needed: it is meant to stop
    # serve_forever() from another thread, and serve_forever() has already
    # returned by the time KeyboardInterrupt reaches the except clause.
    with http.server.HTTPServer((host, port), DemoHandler) as server:
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            print("\nShutting down...", flush=True)


if __name__ == "__main__":
    # Start the server only when executed as a script, not on import.
    main()
