tuned_cache.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import functools
  2. import sys
  3. from copy import deepcopy
  4. assert 'joblib' not in sys.modules, 'Import tuned cache before joblib'
  5. # noinspection PyProtectedMember,PyPep8
  6. import joblib
  7. # noinspection PyProtectedMember,PyPep8
  8. from joblib.func_inspect import _clean_win_chars
  9. # noinspection PyProtectedMember,PyPep8
  10. from joblib.memory import MemorizedFunc, _FUNCTION_HASHES, NotMemorizedFunc, Memory
# In-memory memo of get_func_name results, keyed by the tuple
# (func, resolv_alias, win_characters). It is filled and pruned by
# tuned_get_func_name further down in this module.
_FUNC_NAMES = {}
  12. # noinspection SpellCheckingInspection
  13. class TunedMemory(Memory):
  14. def cache(self, func=None, ignore=None, verbose=None, mmap_mode=False):
  15. """ Decorates the given function func to only compute its return
  16. value for input arguments not cached on disk.
  17. Parameters
  18. ----------
  19. func: callable, optional
  20. The function to be decorated
  21. ignore: list of strings
  22. A list of arguments name to ignore in the hashing
  23. verbose: integer, optional
  24. The verbosity mode of the function. By default that
  25. of the memory object is used.
  26. mmap_mode: {None, 'r+', 'r', 'w+', 'c'}, optional
  27. The memmapping mode used when loading from cache
  28. numpy arrays. See numpy.load for the meaning of the
  29. arguments. By default that of the memory object is used.
  30. Returns
  31. -------
  32. decorated_func: MemorizedFunc object
  33. The returned object is a MemorizedFunc object, that is
  34. callable (behaves like a function), but offers extra
  35. methods for cache lookup and management. See the
  36. documentation for :class:`joblib.memory.MemorizedFunc`.
  37. """
  38. if func is None:
  39. # Partial application, to be able to specify extra keyword
  40. # arguments in decorators
  41. return functools.partial(self.cache, ignore=ignore,
  42. verbose=verbose, mmap_mode=mmap_mode)
  43. if self.store_backend is None:
  44. return NotMemorizedFunc(func)
  45. if verbose is None:
  46. verbose = self._verbose
  47. if mmap_mode is False:
  48. mmap_mode = self.mmap_mode
  49. if isinstance(func, TunedMemorizedFunc):
  50. func = func.func
  51. return TunedMemorizedFunc(func, location=self.store_backend,
  52. backend=self.backend,
  53. ignore=ignore, mmap_mode=mmap_mode,
  54. compress=self.compress,
  55. verbose=verbose, timestamp=self.timestamp)
  56. class TunedMemorizedFunc(MemorizedFunc):
  57. def __call__(self, *args, **kwargs):
  58. # Also store in the in-memory store of function hashes
  59. if self.func not in _FUNCTION_HASHES:
  60. is_named_callable = (hasattr(self.func, '__name__') and
  61. self.func.__name__ != '<lambda>')
  62. if is_named_callable:
  63. # Don't do this for lambda functions or strange callable
  64. # objects, as it ends up being too fragile
  65. func_hash = self._hash_func()
  66. try:
  67. _FUNCTION_HASHES[self.func] = func_hash
  68. except TypeError:
  69. # Some callable are not hashable
  70. pass
  71. # return same result as before
  72. return MemorizedFunc.__call__(self, *args, **kwargs)
  73. old_get_func_name = joblib.func_inspect.get_func_name
  74. def tuned_get_func_name(func, resolv_alias=True, win_characters=True):
  75. if (func, resolv_alias, win_characters) not in _FUNC_NAMES:
  76. _FUNC_NAMES[(func, resolv_alias, win_characters)] = old_get_func_name(func, resolv_alias, win_characters)
  77. if len(_FUNC_NAMES) > 1000:
  78. # keep cache small and fast
  79. for idx, k in enumerate(_FUNC_NAMES.keys()):
  80. if idx % 2:
  81. del _FUNC_NAMES[k]
  82. # print('cache size ', len(_FUNC_NAMES))
  83. return deepcopy(_FUNC_NAMES[(func, resolv_alias, win_characters)])
  84. joblib.func_inspect.get_func_name = tuned_get_func_name
  85. joblib.memory.get_func_name = tuned_get_func_name
  86. def main():
  87. class A:
  88. test_cache = TunedMemory('.cache/test_cache', verbose=1)
  89. def __init__(self, a):
  90. self.a = a
  91. self.compute = self.test_cache.cache(self.compute)
  92. def compute(self):
  93. return self.a + 1
  94. a1, a2 = A(2), A(2)
  95. print(a1.compute())
  96. print('---')
  97. print(a2.compute())
  98. print('---')
  99. a1.a = 3
  100. print(a1.compute())
  101. print('---')
  102. print(a2.compute())
  103. print('---')
  104. if __name__ == '__main__':
  105. main()