# tuned_cache.py
  1. import functools
  2. import sys
  3. from copy import deepcopy
  4. assert 'joblib' not in sys.modules, 'Import tuned cache before joblib'
  5. # noinspection PyProtectedMember,PyPep8
  6. import joblib
  7. # noinspection PyProtectedMember,PyPep8
  8. from joblib._compat import PY3_OR_LATER
  9. # noinspection PyProtectedMember,PyPep8
  10. from joblib.func_inspect import _clean_win_chars
  11. # noinspection PyProtectedMember,PyPep8
  12. from joblib.memory import MemorizedFunc, _FUNCTION_HASHES, NotMemorizedFunc, Memory
  13. _FUNC_NAMES = {}
  14. # noinspection SpellCheckingInspection
  15. class TunedMemory(Memory):
  16. def cache(self, func=None, ignore=None, verbose=None, mmap_mode=False):
  17. """ Decorates the given function func to only compute its return
  18. value for input arguments not cached on disk.
  19. Parameters
  20. ----------
  21. func: callable, optional
  22. The function to be decorated
  23. ignore: list of strings
  24. A list of arguments name to ignore in the hashing
  25. verbose: integer, optional
  26. The verbosity mode of the function. By default that
  27. of the memory object is used.
  28. mmap_mode: {None, 'r+', 'r', 'w+', 'c'}, optional
  29. The memmapping mode used when loading from cache
  30. numpy arrays. See numpy.load for the meaning of the
  31. arguments. By default that of the memory object is used.
  32. Returns
  33. -------
  34. decorated_func: MemorizedFunc object
  35. The returned object is a MemorizedFunc object, that is
  36. callable (behaves like a function), but offers extra
  37. methods for cache lookup and management. See the
  38. documentation for :class:`joblib.memory.MemorizedFunc`.
  39. """
  40. if func is None:
  41. # Partial application, to be able to specify extra keyword
  42. # arguments in decorators
  43. return functools.partial(self.cache, ignore=ignore,
  44. verbose=verbose, mmap_mode=mmap_mode)
  45. if self.store_backend is None:
  46. return NotMemorizedFunc(func)
  47. if verbose is None:
  48. verbose = self._verbose
  49. if mmap_mode is False:
  50. mmap_mode = self.mmap_mode
  51. if isinstance(func, TunedMemorizedFunc):
  52. func = func.func
  53. return TunedMemorizedFunc(func, location=self.store_backend,
  54. backend=self.backend,
  55. ignore=ignore, mmap_mode=mmap_mode,
  56. compress=self.compress,
  57. verbose=verbose, timestamp=self.timestamp)
  58. class TunedMemorizedFunc(MemorizedFunc):
  59. def __call__(self, *args, **kwargs):
  60. # Also store in the in-memory store of function hashes
  61. if self.func not in _FUNCTION_HASHES:
  62. if PY3_OR_LATER:
  63. is_named_callable = (hasattr(self.func, '__name__') and
  64. self.func.__name__ != '<lambda>')
  65. else:
  66. is_named_callable = (hasattr(self.func, 'func_name') and
  67. self.func.func_name != '<lambda>')
  68. if is_named_callable:
  69. # Don't do this for lambda functions or strange callable
  70. # objects, as it ends up being too fragile
  71. func_hash = self._hash_func()
  72. try:
  73. _FUNCTION_HASHES[self.func] = func_hash
  74. except TypeError:
  75. # Some callable are not hashable
  76. pass
  77. # return same result as before
  78. return MemorizedFunc.__call__(self, *args, **kwargs)
  79. old_get_func_name = joblib.func_inspect.get_func_name
  80. def tuned_get_func_name(func, resolv_alias=True, win_characters=True):
  81. if (func, resolv_alias, win_characters) not in _FUNC_NAMES:
  82. _FUNC_NAMES[(func, resolv_alias, win_characters)] = old_get_func_name(func, resolv_alias, win_characters)
  83. if len(_FUNC_NAMES) > 1000:
  84. # keep cache small and fast
  85. for idx, k in enumerate(_FUNC_NAMES.keys()):
  86. if idx % 2:
  87. del _FUNC_NAMES[k]
  88. # print('cache size ', len(_FUNC_NAMES))
  89. return deepcopy(_FUNC_NAMES[(func, resolv_alias, win_characters)])
  90. joblib.func_inspect.get_func_name = tuned_get_func_name
  91. joblib.memory.get_func_name = tuned_get_func_name
  92. def main():
  93. class A:
  94. test_cache = TunedMemory('.cache/test_cache', verbose=1)
  95. def __init__(self, a):
  96. self.a = a
  97. self.compute = self.test_cache.cache(self.compute)
  98. def compute(self):
  99. return self.a + 1
  100. a1, a2 = A(2), A(2)
  101. print(a1.compute())
  102. print('---')
  103. print(a2.compute())
  104. print('---')
  105. a1.a = 3
  106. print(a1.compute())
  107. print('---')
  108. print(a2.compute())
  109. print('---')
  110. if __name__ == '__main__':
  111. main()