From 2df3f358fc2450a0f946a98e0a570b89f3def0bd Mon Sep 17 00:00:00 2001
From: Jack Betteridge <J.Betteridge@imperial.ac.uk>
Date: Fri, 23 Aug 2024 19:35:08 +0100
Subject: [PATCH] Add additional cache tests

---
 test/unit/test_updated_caching.py | 52 +++++++++++++++++++++++++++++--
 1 file changed, 49 insertions(+), 3 deletions(-)

diff --git a/test/unit/test_updated_caching.py b/test/unit/test_updated_caching.py
index 93af9f46c..5066554a1 100644
--- a/test/unit/test_updated_caching.py
+++ b/test/unit/test_updated_caching.py
@@ -1,14 +1,19 @@
+import ctypes
 import pytest
+import os
+import tempfile
 from functools import partial
+from itertools import chain
+from textwrap import dedent
 
-from pyop2.caching import (  # noqa: F401
+from pyop2.caching import (
     disk_only_cache,
     memory_cache,
     memory_and_disk_cache,
-    default_parallel_hashkey,
     clear_memory_cache
 )
-from pyop2.mpi import MPI, COMM_WORLD, comm_cache_keyval  # noqa: F401
+from pyop2.compilation import load
+from pyop2.mpi import MPI, COMM_WORLD
 
 
 class StateIncrement:
@@ -138,3 +143,44 @@ def test_function_over_different_comms(request, state, decorator, uncached_funct
             comm23.Free()
 
     clear_memory_cache(COMM_WORLD)
+
+
+# pyop2/compilation.py uses a custom cache which we test here
+@pytest.mark.parallel(nprocs=2)
+def test_writing_large_so():
+    # This test exercises the compilation caching when handling larger files
+    if COMM_WORLD.rank == 0:
+        preamble = dedent("""\
+            #include <stdio.h>\n
+            void big(double *result){
+            """)
+        variables = (f"v{next(tempfile._get_candidate_names())}" for _ in range(128*1024))
+        lines = (f"  double {v} = {hash(v)/1000000000};\n  *result += {v};\n" for v in variables)
+        program = "\n".join(chain.from_iterable(((preamble, ), lines, ("}\n", ))))
+        with open("big.c", "w") as fh:
+            fh.write(program)
+
+    COMM_WORLD.Barrier()
+    with open("big.c", "r") as fh:
+        program = fh.read()
+
+    if COMM_WORLD.rank == 1:
+        os.remove("big.c")
+
+    fn = load(program, "c", "big", argtypes=(ctypes.c_voidp,), comm=COMM_WORLD)
+    assert fn is not None
+
+
+@pytest.mark.parallel(nprocs=2)
+def test_two_comms_compile_the_same_code():
+    new_comm = COMM_WORLD.Split(color=COMM_WORLD.rank)
+    new_comm.name = "test_two_comms"
+    code = dedent("""\
+        #include <stdio.h>\n
+        void noop(){
+          printf("Do nothing!\\n");
+        }
+        """)
+
+    fn = load(code, "c", "noop", argtypes=(), comm=COMM_WORLD)
+    assert fn is not None