package example;

import com.tangosol.net.CacheFactory;
import com.tangosol.net.NamedCache;

import javax.management.AttributeNotFoundException;
import javax.management.MBeanServer;
import javax.management.ObjectName;
import java.io.IOException;
import java.lang.management.ManagementFactory;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

/**
 * Demonstrates the Coherence {@code CacheMBean} hit/miss statistics of a single-tier local cache.
 *
 * <p>The demo runs in five steps:
 * <ol>
 *   <li>It writes a temporary cache configuration. The configuration maps {@value #CACHE_NAME} to a
 *       {@code local-scheme} with {@code high-units = N} and {@code low-units = floor(N * pruneFraction)}.</li>
 *   <li>It preloads exactly N entries.</li>
 *   <li>It runs a get-only workload. Half of each thread's requests target preloaded keys (expected hits)
 *       and the other half target keys that were never written (expected misses).</li>
 *   <li>It prints selected {@code CacheMBean} attributes from the platform MBean server.</li>
 *   <li>It shuts Coherence down and deletes the temporary configuration.</li>
 * </ol>
 *
 * <p>Optional command-line flags: {@code --max-units}, {@code --prune-fraction}, {@code --threads},
 * {@code --requests-per-thread} and {@code --grace-millis}.
 */
public final class CoherenceLocalJmxHitsMissesDemo {

    // ---- Cache identifiers ----
    private static final String CACHE_NAME   = "demo-local";
    private static final String SERVICE_NAME = "demo-local-service";
    private static final String SCHEME_NAME  = "demo-local-scheme";

    // ---- Defaults (CLI overridable) ----
    private static final int    DEFAULT_MAX_UNITS          = 10_000;    // N
    private static final double DEFAULT_PRUNE_FRACTION     = 0.90;      // LowUnits = 90% of N
    private static final int    DEFAULT_THREADS            = 4;         // T
    private static final int    DEFAULT_REQUESTS_PER_THREAD= 50_000;    // M
    private static final long   DEFAULT_GRACE_MILLIS       = 500;       // settle time

    /** Upper bound on how long the workload may run before it is treated as stuck. */
    private static final long WORKLOAD_TIMEOUT_MINUTES = 30;

    /**
     * CacheMBean attributes printed for every matching MBean, in this order.
     * Attributes that an MBean does not define are skipped.
     */
    private static final String[] REPORTED_ATTRIBUTES = {
            "Units", "HighUnits", "LowUnits", "Size",
            "TotalGets", "TotalPuts", "CacheHits", "CacheMisses",
            "HitProbability", "CachePrunes"
    };

    /**
     * Configures and starts Coherence, runs the workload and prints the cache MBean statistics.
     *
     * <p>Coherence is shut down and the temporary configuration file is deleted even when a step fails.
     *
     * @param args optional command-line flags (see the class documentation)
     * @throws Exception if the configuration cannot be written, the workload fails, or JMX access fails
     */
    public static void main(final String[] args) throws Exception {
        final Args a = Args.parse(args);

        // 1) Inline single-tier local cache config with high-units=N, low-units=floor(N*pruneFraction)
        final Path cfg = writeLocalCacheConfig(a.maxUnits, a.lowUnits());
        try {
            // 2) Enable Coherence + JMX. These properties are set before Coherence is first used below.
            System.setProperty("coherence.cacheconfig", cfg.toAbsolutePath().toString());
            System.setProperty("coherence.management", "all"); // expose MBeans in this JVM

            // 3) Start Coherence and obtain the cache
            final NamedCache<String, String> cache = CacheFactory.getCache(CACHE_NAME);

            // 4) Preload exactly N entries: k0..k{N-1}
            for (int i = 0; i < a.maxUnits; i++) {
                cache.put("k" + i, "v" + i);
            }

            // 5) Workload: each thread does M/2 gets in-range (0..N-1) and M/2 out-of-range (N..2N-1)
            runWorkloadGetsOnly(cache, a);

            // 6) Give the statistics a moment to settle, then dump the selected MBean attributes
            if (a.graceMillis > 0) {
                Thread.sleep(a.graceMillis);
            }
            printSelectedCacheMBeanStats(SERVICE_NAME, CACHE_NAME);

            cache.destroy();
        } finally {
            // Always stop Coherence and remove the temp config, so a failure part-way through
            // does not leave services running or files behind.
            try {
                CacheFactory.shutdown();
            } finally {
                Files.deleteIfExists(cfg);
            }
        }
    }

    // ---------- Workload: gets only ----------

    /**
     * Runs the get-only workload on {@code a.threads} worker threads and waits for it to finish.
     *
     * <p>Each thread issues {@code floor(M/2)} gets for keys {@code k0..k{N-1}}, which were preloaded and
     * are expected to hit. It then issues {@code M - floor(M/2)} gets for keys {@code kN..k{2N-1}}, which were
     * never written and are expected to miss.
     *
     * @param cache the preloaded cache to read from
     * @param a     parsed arguments supplying N, T and M
     * @throws InterruptedException  if interrupted while waiting; the workers are interrupted as well
     * @throws IllegalStateException if a worker fails, or if the workload does not finish within
     *                               {@link #WORKLOAD_TIMEOUT_MINUTES} minutes
     */
    private static void runWorkloadGetsOnly(final NamedCache<String, String> cache, final Args a) throws InterruptedException {
        final ExecutorService pool = Executors.newFixedThreadPool(a.threads);
        final int n = a.maxUnits;
        final int getsInRange = a.requestsPerThread / 2;
        final int getsOutOfRange = a.requestsPerThread - getsInRange; // handle odd M

        final List<Future<?>> futures = new ArrayList<>(a.threads);
        try {
            for (int t = 0; t < a.threads; t++) {
                futures.add(pool.submit(() -> {
                    final ThreadLocalRandom rnd = ThreadLocalRandom.current();

                    // In-range gets (expected hits)
                    for (int i = 0; i < getsInRange; i++) {
                        final int k = rnd.nextInt(n); // [0, N)
                        cache.get("k" + k);
                    }

                    // Out-of-range gets (expected misses). Long arithmetic keeps N + r from overflowing.
                    for (int i = 0; i < getsOutOfRange; i++) {
                        final long k = (long) n + rnd.nextInt(n); // [N, 2N)
                        cache.get("k" + k);
                    }
                }));
            }
        } finally {
            pool.shutdown();
        }

        final boolean finished;
        try {
            finished = pool.awaitTermination(WORKLOAD_TIMEOUT_MINUTES, TimeUnit.MINUTES);
        } catch (final InterruptedException e) {
            pool.shutdownNow();
            throw e;
        }
        if (!finished) {
            pool.shutdownNow();
            throw new IllegalStateException(
                    "Workload did not finish within " + WORKLOAD_TIMEOUT_MINUTES + " minutes");
        }

        // submit() captures worker exceptions inside the futures, so re-raise them here.
        for (final Future<?> f : futures) {
            try {
                f.get();
            } catch (final ExecutionException e) {
                throw new IllegalStateException("Workload task failed", e.getCause());
            }
        }
    }

    // ---------- JMX: print selected attributes ----------

    /**
     * Prints {@link #REPORTED_ATTRIBUTES} for every {@code Coherence:type=Cache} MBean whose
     * {@code service} and {@code name} key properties match the arguments.
     *
     * <p>The key properties may be quoted or unquoted. If no MBean matches, a single message is printed.
     *
     * @param serviceName expected value of the {@code service} key property
     * @param cacheName   expected value of the {@code name} key property
     * @throws Exception if the JMX query pattern is malformed
     */
    private static void printSelectedCacheMBeanStats(final String serviceName, final String cacheName) throws Exception {
        final MBeanServer mbs = ManagementFactory.getPlatformMBeanServer();
        final Set<ObjectName> all = mbs.queryNames(new ObjectName("Coherence:type=Cache,*"), null);

        boolean found = false;
        for (final ObjectName on : all) {
            final String svc = on.getKeyProperty("service");
            final String nam = on.getKeyProperty("name");
            if (!quoteAwareEquals(svc, serviceName) || !quoteAwareEquals(nam, cacheName)) {
                continue;
            }
            found = true;
            System.out.println("\n--- " + on + " ---");
            for (final String attr : REPORTED_ATTRIBUTES) {
                printAttr(mbs, on, attr);
            }
        }
        if (!found) {
            System.out.printf("No CacheMBean instances found for service=%s, name=%s%n", serviceName, cacheName);
        }
    }

    /**
     * Prints one attribute of an MBean as {@code name: value}.
     *
     * <p>Attributes the MBean does not define are skipped. Any other failure is printed as
     * {@code <error: ExceptionType>} instead of being thrown.
     */
    private static void printAttr(final MBeanServer mbs, final ObjectName on, final String attr) {
        try {
            final Object v = mbs.getAttribute(on, attr);
            System.out.printf("%-14s: %s%n", attr, v);
        } catch (final AttributeNotFoundException ignored) {
            // Deliberately skipped: the report only lists attributes this MBean actually defines.
        } catch (final Exception e) {
            System.out.printf("%-14s: <error: %s>%n", attr, e.getClass().getSimpleName());
        }
    }

    /**
     * Checks whether an ObjectName key-property value equals {@code expected}, in raw or
     * {@link ObjectName#quote(String) quoted} form.
     *
     * @return {@code false} when {@code val} is {@code null}
     */
    private static boolean quoteAwareEquals(final String val, final String expected) {
        if (val == null) {
            return false;
        }
        return val.equals(expected) || val.equals(ObjectName.quote(expected));
    }

    // ---------- Inline cache-config ----------

    /**
     * Writes a cache configuration to a new temporary file and returns its path.
     *
     * <p>The configuration maps {@link #CACHE_NAME} to an LRU {@code local-scheme} that uses the FIXED
     * unit calculator and the given high/low unit limits. The caller is responsible for deleting the file.
     *
     * @param highUnits value for {@code <high-units>}
     * @param lowUnits  value for {@code <low-units>}
     * @return path of the written UTF-8 XML file
     * @throws IOException if the temporary file cannot be created or written
     */
    private static Path writeLocalCacheConfig(final int highUnits, final int lowUnits) throws IOException {
        final String xml =
                "<?xml version=\"1.0\"?>\n" +
                "<cache-config xmlns=\"http://xmlns.oracle.com/coherence/coherence-cache-config\">\n" +
                "  <caching-scheme-mapping>\n" +
                "    <cache-mapping>\n" +
                "      <cache-name>" + CACHE_NAME + "</cache-name>\n" +
                "      <scheme-name>" + SCHEME_NAME + "</scheme-name>\n" +
                "    </cache-mapping>\n" +
                "  </caching-scheme-mapping>\n" +
                "  <caching-schemes>\n" +
                "    <local-scheme>\n" +
                "      <scheme-name>" + SCHEME_NAME + "</scheme-name>\n" +
                "      <service-name>" + SERVICE_NAME + "</service-name>\n" +
                "      <eviction-policy>LRU</eviction-policy>\n" +
                "      <unit-calculator>FIXED</unit-calculator>\n" + // entries are counted as 1 unit each
                "      <high-units>" + highUnits + "</high-units>\n" +
                "      <low-units>" + lowUnits + "</low-units>\n" +
                "    </local-scheme>\n" +
                "  </caching-schemes>\n" +
                "</cache-config>\n";

        final Path tmp = Files.createTempFile("coh-local-hits-misses-", ".xml");
        Files.write(tmp, xml.getBytes(StandardCharsets.UTF_8));
        return tmp;
    }

    // ---------- CLI ----------

    /** Immutable, validated command-line settings. */
    private static final class Args {
        final int maxUnits;          // N
        final double pruneFraction;  // e.g., 0.90 => LowUnits=0.9*N
        final int threads;           // T
        final int requestsPerThread; // M
        final long graceMillis;      // a value <= 0 disables the settle sleep

        /**
         * Creates validated settings.
         *
         * @throws IllegalArgumentException if {@code maxUnits < 1}, {@code threads < 1},
         *         {@code requestsPerThread < 0}, or {@code pruneFraction} is NaN or outside [0, 1]
         */
        Args(final int maxUnits, final double pruneFraction, final int threads, final int requestsPerThread, final long graceMillis) {
            if (maxUnits < 1) {
                throw new IllegalArgumentException("--max-units must be >= 1, got " + maxUnits);
            }
            // Negated range check so that NaN is rejected too.
            if (!(pruneFraction >= 0.0 && pruneFraction <= 1.0)) {
                throw new IllegalArgumentException("--prune-fraction must be within [0, 1], got " + pruneFraction);
            }
            if (threads < 1) {
                throw new IllegalArgumentException("--threads must be >= 1, got " + threads);
            }
            if (requestsPerThread < 0) {
                throw new IllegalArgumentException("--requests-per-thread must be >= 0, got " + requestsPerThread);
            }
            this.maxUnits = maxUnits;
            this.pruneFraction = pruneFraction;
            this.threads = threads;
            this.requestsPerThread = requestsPerThread;
            this.graceMillis = graceMillis;
        }

        /** Returns {@code floor(maxUnits * pruneFraction)}, clamped to {@code [0, Integer.MAX_VALUE]}. */
        int lowUnits() {
            final double v = Math.floor(maxUnits * pruneFraction);
            return (int) Math.max(0, Math.min(Integer.MAX_VALUE, v));
        }

        /**
         * Parses command-line flags and falls back to the class defaults for any flag that is absent.
         * Unknown arguments are reported on stderr and otherwise ignored.
         *
         * @throws IllegalArgumentException if a flag is missing its value or a value is out of range
         * @throws NumberFormatException    if a flag value is not a valid number
         */
        static Args parse(final String[] args) {
            int maxUnits = DEFAULT_MAX_UNITS;
            double pruneFraction = DEFAULT_PRUNE_FRACTION;
            int threads = DEFAULT_THREADS;
            int reqPerThread = DEFAULT_REQUESTS_PER_THREAD;
            long graceMillis = DEFAULT_GRACE_MILLIS;

            for (int i = 0; i < args.length; i++) {
                final String a = args[i];
                switch (a) {
                    case "--max-units":
                        maxUnits = Integer.parseInt(req(args, ++i, "--max-units requires a value"));
                        break;
                    case "--prune-fraction":
                        pruneFraction = Double.parseDouble(req(args, ++i, "--prune-fraction requires a value (e.g. 0.9)"));
                        break;
                    case "--threads":
                        threads = Integer.parseInt(req(args, ++i, "--threads requires a value"));
                        break;
                    case "--requests-per-thread":
                        reqPerThread = Integer.parseInt(req(args, ++i, "--requests-per-thread requires a value"));
                        break;
                    case "--grace-millis":
                        graceMillis = Long.parseLong(req(args, ++i, "--grace-millis requires a value"));
                        break;
                    default:
                        // Warn rather than fail, so that typos such as "--max-unit" do not go unnoticed.
                        System.err.println("Ignoring unknown argument: " + a);
                }
            }
            return new Args(maxUnits, pruneFraction, threads, reqPerThread, graceMillis);
        }

        /** Returns {@code args[idx]}, or throws {@link IllegalArgumentException} with {@code err} if it is absent. */
        private static String req(final String[] args, final int idx, final String err) {
            if (idx >= args.length) {
                throw new IllegalArgumentException(err);
            }
            return args[idx];
        }
    }
}
