Reference (Gold): tinydb
Pytest Summary for tests
| status    | count |
|-----------|-------|
| passed    | 201   |
| total     | 201   |
| collected | 201   |
Failed pytests: none
Patch diff
diff --git a/tinydb/database.py b/tinydb/database.py
index a4ce0e1..4a73c46 100644
--- a/tinydb/database.py
+++ b/tinydb/database.py
@@ -2,10 +2,14 @@
This module contains the main component of TinyDB: the database.
"""
from typing import Dict, Iterator, Set, Type
+
from . import JSONStorage
from .storages import Storage
from .table import Table, Document
from .utils import with_typehint
+
+# The table's base class. This is used to add type hinting from the Table
+# class to TinyDB. Currently, this supports PyCharm, Pyright/VS Code and MyPy.
TableBase: Type[Table] = with_typehint(Table)
@@ -63,28 +67,48 @@ class TinyDB(TableBase):
:param storage: The class of the storage to use. Will be initialized
with ``args`` and ``kwargs``.
"""
+
+ #: The class that will be used to create table instances
+ #:
+ #: .. versionadded:: 4.0
table_class = Table
+
+ #: The name of the default table
+ #:
+ #: .. versionadded:: 4.0
default_table_name = '_default'
+
+ #: The class that will be used by default to create storage instances
+ #:
+ #: .. versionadded:: 4.0
default_storage_class = JSONStorage
- def __init__(self, *args, **kwargs) ->None:
+ def __init__(self, *args, **kwargs) -> None:
"""
Create a new instance of TinyDB.
"""
+
storage = kwargs.pop('storage', self.default_storage_class)
+
+ # Prepare the storage
self._storage: Storage = storage(*args, **kwargs)
+
self._opened = True
self._tables: Dict[str, Table] = {}
def __repr__(self):
- args = ['tables={}'.format(list(self.tables())), 'tables_count={}'.
- format(len(self.tables())), 'default_table_documents_count={}'.
- format(self.__len__()), 'all_tables_documents_count={}'.format(
- ['{}={}'.format(table, len(self.table(table))) for table in
- self.tables()])]
+ args = [
+ 'tables={}'.format(list(self.tables())),
+ 'tables_count={}'.format(len(self.tables())),
+ 'default_table_documents_count={}'.format(self.__len__()),
+ 'all_tables_documents_count={}'.format(
+ ['{}={}'.format(table, len(self.table(table)))
+ for table in self.tables()]),
+ ]
+
return '<{} {}>'.format(type(self).__name__, ', '.join(args))
- def table(self, name: str, **kwargs) ->Table:
+ def table(self, name: str, **kwargs) -> Table:
"""
Get access to a specific table.
@@ -99,41 +123,95 @@ class TinyDB(TableBase):
:param name: The name of the table.
:param kwargs: Keyword arguments to pass to the table class constructor
"""
- pass
- def tables(self) ->Set[str]:
+ if name in self._tables:
+ return self._tables[name]
+
+ table = self.table_class(self.storage, name, **kwargs)
+ self._tables[name] = table
+
+ return table
+
+ def tables(self) -> Set[str]:
"""
Get the names of all tables in the database.
:returns: a set of table names
"""
- pass
- def drop_tables(self) ->None:
+ # TinyDB stores data as a dict of tables like this:
+ #
+ # {
+ # '_default': {
+ # 0: {document...},
+ # 1: {document...},
+ # },
+ # 'table1': {
+ # ...
+ # }
+ # }
+ #
+ # To get a set of table names, we thus construct a set of this main
+ # dict which returns a set of the dict keys which are the table names.
+ #
+ # Storage.read() may return ``None`` if the database file is empty,
+ # so we need to consider this case too and return an empty set then.
+
+ return set(self.storage.read() or {})
+
+ def drop_tables(self) -> None:
"""
Drop all tables from the database. **CANNOT BE REVERSED!**
"""
- pass
- def drop_table(self, name: str) ->None:
+ # We drop all tables from this database by writing an empty dict
+ # to the storage thereby returning to the initial state with no tables.
+ self.storage.write({})
+
+ # After that we need to remember to empty the ``_tables`` dict, so we'll
+ # create new table instances when a table is accessed again.
+ self._tables.clear()
+
+ def drop_table(self, name: str) -> None:
"""
Drop a specific table from the database. **CANNOT BE REVERSED!**
:param name: The name of the table to drop.
"""
- pass
+
+ # If the table is currently opened, we need to forget the table class
+ # instance
+ if name in self._tables:
+ del self._tables[name]
+
+ data = self.storage.read()
+
+ # The database is uninitialized, there's nothing to do
+ if data is None:
+ return
+
+ # The table does not exist, there's nothing to do
+ if name not in data:
+ return
+
+ # Remove the table from the data dict
+ del data[name]
+
+ # Store the updated data back to the storage
+ self.storage.write(data)
@property
- def storage(self) ->Storage:
+ def storage(self) -> Storage:
"""
Get the storage instance used for this TinyDB instance.
:return: This instance's storage
:rtype: Storage
"""
- pass
+ return self._storage
- def close(self) ->None:
+ def close(self) -> None:
"""
Close the database.
@@ -148,7 +226,8 @@ class TinyDB(TableBase):
Upon leaving this context, the ``close`` method will be called.
"""
- pass
+ self._opened = False
+ self.storage.close()
def __enter__(self):
"""
@@ -175,6 +254,9 @@ class TinyDB(TableBase):
"""
return getattr(self.table(self.default_table_name), name)
+ # Here we forward magic methods to the default table instance. These are
+ # not handled by __getattr__ so we need to forward them manually here
+
def __len__(self):
"""
Get the total number of documents in the default table.
@@ -185,7 +267,7 @@ class TinyDB(TableBase):
"""
return len(self.table(self.default_table_name))
- def __iter__(self) ->Iterator[Document]:
+ def __iter__(self) -> Iterator[Document]:
"""
Return an iterator for the default table's documents.
"""
diff --git a/tinydb/middlewares.py b/tinydb/middlewares.py
index 50c2af2..7973012 100644
--- a/tinydb/middlewares.py
+++ b/tinydb/middlewares.py
@@ -3,6 +3,7 @@ Contains the :class:`base class <tinydb.middlewares.Middleware>` for
middlewares and implementations.
"""
from typing import Optional
+
from tinydb import Storage
@@ -17,9 +18,9 @@ class Middleware:
constructor so the middleware chain can be configured properly.
"""
- def __init__(self, storage_cls) ->None:
+ def __init__(self, storage_cls) -> None:
self._storage_cls = storage_cls
- self.storage: Storage = None
+ self.storage: Storage = None # type: ignore
def __call__(self, *args, **kwargs):
"""
@@ -58,7 +59,9 @@ class Middleware:
nested Middleware that itself will initialize the next Middleware and
so on.
"""
+
self.storage = self._storage_cls(*args, **kwargs)
+
return self
def __getattr__(self, name):
@@ -66,6 +69,7 @@ class Middleware:
Forward all unknown attribute calls to the underlying storage, so we
remain as transparent as possible.
"""
+
return getattr(self.__dict__['storage'], name)
@@ -77,15 +81,47 @@ class CachingMiddleware(Middleware):
the last DB state every :attr:`WRITE_CACHE_SIZE` time and reading always
from cache.
"""
+
+ #: The number of write operations to cache before writing to disc
WRITE_CACHE_SIZE = 1000
def __init__(self, storage_cls):
+ # Initialize the parent constructor
super().__init__(storage_cls)
+
+ # Prepare the cache
self.cache = None
self._cache_modified_count = 0
+ def read(self):
+ if self.cache is None:
+ # Empty cache: read from the storage
+ self.cache = self.storage.read()
+
+ # Return the cached data
+ return self.cache
+
+ def write(self, data):
+ # Store data in cache
+ self.cache = data
+ self._cache_modified_count += 1
+
+ # Check if we need to flush the cache
+ if self._cache_modified_count >= self.WRITE_CACHE_SIZE:
+ self.flush()
+
def flush(self):
"""
Flush all unwritten data to disk.
"""
- pass
+ if self._cache_modified_count > 0:
+ # Force-flush the cache by writing the data to the storage
+ self.storage.write(self.cache)
+ self._cache_modified_count = 0
+
+ def close(self):
+ # Flush potentially unwritten data
+ self.flush()
+
+ # Let the storage clean up too
+ self.storage.close()
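
The completed `CachingMiddleware` buffers reads and writes in memory and only flushes to the wrapped storage every `WRITE_CACHE_SIZE` writes, on `flush()`, or on `close()`. A hedged sketch of how it is typically wired up (`example_db.json` is a placeholder path):

```python
from tinydb import TinyDB
from tinydb.middlewares import CachingMiddleware
from tinydb.storages import JSONStorage

db = TinyDB('example_db.json', storage=CachingMiddleware(JSONStorage))

db.insert({'key': 'value'})   # lands in the cache, not yet on disk
db.storage.flush()            # force the pending write to the wrapped JSONStorage
db.close()                    # flush() plus close() on the underlying storage
```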
diff --git a/tinydb/mypy_plugin.py b/tinydb/mypy_plugin.py
index 5a0191a..cef1005 100644
--- a/tinydb/mypy_plugin.py
+++ b/tinydb/mypy_plugin.py
@@ -1,14 +1,38 @@
from typing import TypeVar, Optional, Callable, Dict
+
from mypy.nodes import NameExpr
from mypy.options import Options
from mypy.plugin import Plugin, DynamicClassDefContext
+
T = TypeVar('T')
CB = Optional[Callable[[T], None]]
DynamicClassDef = DynamicClassDefContext
class TinyDBPlugin(Plugin):
-
def __init__(self, options: Options):
super().__init__(options)
+
self.named_placeholders: Dict[str, str] = {}
+
+ def get_dynamic_class_hook(self, fullname: str) -> CB[DynamicClassDef]:
+ if fullname == 'tinydb.utils.with_typehint':
+ def hook(ctx: DynamicClassDefContext):
+ klass = ctx.call.args[0]
+ assert isinstance(klass, NameExpr)
+
+ type_name = klass.fullname
+ assert type_name is not None
+
+ qualified = self.lookup_fully_qualified(type_name)
+ assert qualified is not None
+
+ ctx.api.add_symbol_table_node(ctx.name, qualified)
+
+ return hook
+
+ return None
+
+
+def plugin(_version: str):
+ return TinyDBPlugin
diff --git a/tinydb/operations.py b/tinydb/operations.py
index fdfa678..47c3492 100644
--- a/tinydb/operations.py
+++ b/tinydb/operations.py
@@ -13,39 +13,57 @@ def delete(field):
"""
Delete a given field from the document.
"""
- pass
+ def transform(doc):
+ del doc[field]
+
+ return transform
def add(field, n):
"""
Add ``n`` to a given field in the document.
"""
- pass
+ def transform(doc):
+ doc[field] += n
+
+ return transform
def subtract(field, n):
"""
Subtract ``n`` from a given field in the document.
"""
- pass
+ def transform(doc):
+ doc[field] -= n
+
+ return transform
def set(field, val):
"""
Set a given field to ``val``.
"""
- pass
+ def transform(doc):
+ doc[field] = val
+
+ return transform
def increment(field):
"""
Increment a given field in the document by 1.
"""
- pass
+ def transform(doc):
+ doc[field] += 1
+
+ return transform
def decrement(field):
"""
Decrement a given field in the document by 1.
"""
- pass
+ def transform(doc):
+ doc[field] -= 1
+
+ return transform
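
Each operation above is a factory returning a `transform(doc)` callable that `Table.update()` applies to every matching document. A short sketch, again using the in-memory storage:

```python
from tinydb import TinyDB, where
from tinydb.operations import add, delete, increment
from tinydb.storages import MemoryStorage

db = TinyDB(storage=MemoryStorage)
db.insert({'name': 'John', 'age': 21, 'temp': True})

db.update(increment('age'), where('name') == 'John')   # age -> 22
db.update(add('age', 5), where('name') == 'John')      # age -> 27
db.update(delete('temp'), where('name') == 'John')     # drops the 'temp' field
```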
diff --git a/tinydb/queries.py b/tinydb/queries.py
index 0ad5c7e..a797b4b 100644
--- a/tinydb/queries.py
+++ b/tinydb/queries.py
@@ -15,15 +15,23 @@ True
>>> q({'val': 1})
False
"""
+
import re
import sys
from typing import Mapping, Tuple, Callable, Any, Union, List, Optional
+
from .utils import freeze
+
if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol
-__all__ = 'Query', 'QueryLike', 'where'
+
+__all__ = ('Query', 'QueryLike', 'where')
+
+
+def is_sequence(obj):
+ return hasattr(obj, '__iter__')
class QueryLike(Protocol):
@@ -45,12 +53,9 @@ class QueryLike(Protocol):
See also https://mypy.readthedocs.io/en/stable/protocols.html#simple-user-defined-protocols
"""
+ def __call__(self, value: Mapping) -> bool: ...
- def __call__(self, value: Mapping) ->bool:
- ...
-
- def __hash__(self) ->int:
- ...
+ def __hash__(self) -> int: ...
class QueryInstance:
@@ -70,12 +75,14 @@ class QueryInstance:
instance can be used as a key in a dictionary.
"""
- def __init__(self, test: Callable[[Mapping], bool], hashval: Optional[
- Tuple]):
+ def __init__(self, test: Callable[[Mapping], bool], hashval: Optional[Tuple]):
self._test = test
self._hash = hashval
- def __call__(self, value: Mapping) ->bool:
+ def is_cacheable(self) -> bool:
+ return self._hash is not None
+
+ def __call__(self, value: Mapping) -> bool:
"""
Evaluate the query to check if it matches a specified value.
@@ -84,7 +91,10 @@ class QueryInstance:
"""
return self._test(value)
- def __hash__(self) ->int:
+ def __hash__(self) -> int:
+ # We calculate the query hash by using the ``hashval`` object which
+ # describes this query uniquely, so we can calculate a stable hash
+ # value by simply hashing it
return hash(self._hash)
def __repr__(self):
@@ -93,25 +103,32 @@ class QueryInstance:
def __eq__(self, other: object):
if isinstance(other, QueryInstance):
return self._hash == other._hash
+
return False
- def __and__(self, other: 'QueryInstance') ->'QueryInstance':
+ # --- Query modifiers -----------------------------------------------------
+
+ def __and__(self, other: 'QueryInstance') -> 'QueryInstance':
+ # We use a frozenset for the hash as the AND operation is commutative
+ # (a & b == b & a) and the frozenset does not consider the order of
+ # elements
if self.is_cacheable() and other.is_cacheable():
- hashval = 'and', frozenset([self._hash, other._hash])
+ hashval = ('and', frozenset([self._hash, other._hash]))
else:
hashval = None
- return QueryInstance(lambda value: self(value) and other(value),
- hashval)
+ return QueryInstance(lambda value: self(value) and other(value), hashval)
- def __or__(self, other: 'QueryInstance') ->'QueryInstance':
+ def __or__(self, other: 'QueryInstance') -> 'QueryInstance':
+ # We use a frozenset for the hash as the OR operation is commutative
+ # (a | b == b | a) and the frozenset does not consider the order of
+ # elements
if self.is_cacheable() and other.is_cacheable():
- hashval = 'or', frozenset([self._hash, other._hash])
+ hashval = ('or', frozenset([self._hash, other._hash]))
else:
hashval = None
- return QueryInstance(lambda value: self(value) or other(value), hashval
- )
+ return QueryInstance(lambda value: self(value) or other(value), hashval)
- def __invert__(self) ->'QueryInstance':
+ def __invert__(self) -> 'QueryInstance':
hashval = ('not', self._hash) if self.is_cacheable() else None
return QueryInstance(lambda value: not self(value), hashval)
@@ -149,12 +166,18 @@ class Query(QueryInstance):
``False`` depending on whether the documents match the query or not.
"""
- def __init__(self) ->None:
+ def __init__(self) -> None:
+ # The current path of fields to access when evaluating the object
self._path: Tuple[Union[str, Callable], ...] = ()
+ # Prevent empty queries from being evaluated
def notest(_):
raise RuntimeError('Empty query was evaluated')
- super().__init__(test=notest, hashval=(None,))
+
+ super().__init__(
+ test=notest,
+ hashval=(None,)
+ )
def __repr__(self):
return '{}()'.format(type(self).__name__)
@@ -163,16 +186,36 @@ class Query(QueryInstance):
return super().__hash__()
def __getattr__(self, item: str):
+ # Generate a new query object with the new query path
+ # We use type(self) to get the class of the current query in case
+ # someone uses a subclass of ``Query``
query = type(self)()
+
+ # Now we add the accessed item to the query path ...
query._path = self._path + (item,)
+
+ # ... and update the query hash
query._hash = ('path', query._path) if self.is_cacheable() else None
+
return query
def __getitem__(self, item: str):
+ # A different syntax for ``__getattr__``
+
+ # We cannot call ``getattr(item)`` here as it would try to resolve
+ # the name as a method name first, only then call our ``__getattr__``
+ # method. By calling ``__getattr__`` directly, we make sure that
+ # calling e.g. ``Query()['test']`` will always generate a query for a
+ # document's ``test`` field instead of returning a reference to the
+ # ``Query.test`` method
return self.__getattr__(item)
- def _generate_test(self, test: Callable[[Any], bool], hashval: Tuple,
- allow_empty_path: bool=False) ->QueryInstance:
+ def _generate_test(
+ self,
+ test: Callable[[Any], bool],
+ hashval: Tuple,
+ allow_empty_path: bool = False
+ ) -> QueryInstance:
"""
Generate a query based on a test function that first resolves the query
path.
@@ -181,7 +224,27 @@ class Query(QueryInstance):
:param hashval: The hash of the query.
:return: A :class:`~tinydb.queries.QueryInstance` object
"""
- pass
+ if not self._path and not allow_empty_path:
+ raise ValueError('Query has no path')
+
+ def runner(value):
+ try:
+ # Resolve the path
+ for part in self._path:
+ if isinstance(part, str):
+ value = value[part]
+ else:
+ value = part(value)
+ except (KeyError, TypeError):
+ return False
+ else:
+ # Perform the specified test
+ return test(value)
+
+ return QueryInstance(
+ lambda value: runner(value),
+ (hashval if self.is_cacheable() else None)
+ )
def __eq__(self, rhs: Any):
"""
@@ -191,8 +254,10 @@ class Query(QueryInstance):
:param rhs: The value to compare against
"""
- return self._generate_test(lambda value: value == rhs, ('==', self.
- _path, freeze(rhs)))
+ return self._generate_test(
+ lambda value: value == rhs,
+ ('==', self._path, freeze(rhs))
+ )
def __ne__(self, rhs: Any):
"""
@@ -202,10 +267,12 @@ class Query(QueryInstance):
:param rhs: The value to compare against
"""
- return self._generate_test(lambda value: value != rhs, ('!=', self.
- _path, freeze(rhs)))
+ return self._generate_test(
+ lambda value: value != rhs,
+ ('!=', self._path, freeze(rhs))
+ )
- def __lt__(self, rhs: Any) ->QueryInstance:
+ def __lt__(self, rhs: Any) -> QueryInstance:
"""
Test a dict value for being lower than another value.
@@ -213,10 +280,12 @@ class Query(QueryInstance):
:param rhs: The value to compare against
"""
- return self._generate_test(lambda value: value < rhs, ('<', self.
- _path, rhs))
+ return self._generate_test(
+ lambda value: value < rhs,
+ ('<', self._path, rhs)
+ )
- def __le__(self, rhs: Any) ->QueryInstance:
+ def __le__(self, rhs: Any) -> QueryInstance:
"""
Test a dict value for being lower than or equal to another value.
@@ -224,10 +293,12 @@ class Query(QueryInstance):
:param rhs: The value to compare against
"""
- return self._generate_test(lambda value: value <= rhs, ('<=', self.
- _path, rhs))
+ return self._generate_test(
+ lambda value: value <= rhs,
+ ('<=', self._path, rhs)
+ )
- def __gt__(self, rhs: Any) ->QueryInstance:
+ def __gt__(self, rhs: Any) -> QueryInstance:
"""
Test a dict value for being greater than another value.
@@ -235,10 +306,12 @@ class Query(QueryInstance):
:param rhs: The value to compare against
"""
- return self._generate_test(lambda value: value > rhs, ('>', self.
- _path, rhs))
+ return self._generate_test(
+ lambda value: value > rhs,
+ ('>', self._path, rhs)
+ )
- def __ge__(self, rhs: Any) ->QueryInstance:
+ def __ge__(self, rhs: Any) -> QueryInstance:
"""
Test a dict value for being greater than or equal to another value.
@@ -246,18 +319,23 @@ class Query(QueryInstance):
:param rhs: The value to compare against
"""
- return self._generate_test(lambda value: value >= rhs, ('>=', self.
- _path, rhs))
+ return self._generate_test(
+ lambda value: value >= rhs,
+ ('>=', self._path, rhs)
+ )
- def exists(self) ->QueryInstance:
+ def exists(self) -> QueryInstance:
"""
Test for a dict where a provided key exists.
>>> Query().f1.exists()
"""
- pass
+ return self._generate_test(
+ lambda _: True,
+ ('exists', self._path)
+ )
- def matches(self, regex: str, flags: int=0) ->QueryInstance:
+ def matches(self, regex: str, flags: int = 0) -> QueryInstance:
"""
Run a regex test against a dict value (whole string has to match).
@@ -266,9 +344,15 @@ class Query(QueryInstance):
:param regex: The regular expression to use for matching
:param flags: regex flags to pass to ``re.match``
"""
- pass
+ def test(value):
+ if not isinstance(value, str):
+ return False
+
+ return re.match(regex, value, flags) is not None
+
+ return self._generate_test(test, ('matches', self._path, regex))
- def search(self, regex: str, flags: int=0) ->QueryInstance:
+ def search(self, regex: str, flags: int = 0) -> QueryInstance:
"""
Run a regex test against a dict value (only a substring has to
match).
@@ -278,9 +362,16 @@ class Query(QueryInstance):
:param regex: The regular expression to use for matching
:param flags: regex flags to pass to ``re.match``
"""
- pass
- def test(self, func: Callable[[Mapping], bool], *args) ->QueryInstance:
+ def test(value):
+ if not isinstance(value, str):
+ return False
+
+ return re.search(regex, value, flags) is not None
+
+ return self._generate_test(test, ('search', self._path, regex))
+
+ def test(self, func: Callable[[Mapping], bool], *args) -> QueryInstance:
"""
Run a user-defined test function against a dict value.
@@ -300,9 +391,12 @@ class Query(QueryInstance):
argument
:param args: Additional arguments to pass to the test function
"""
- pass
+ return self._generate_test(
+ lambda value: func(value, *args),
+ ('test', self._path, func, args)
+ )
- def any(self, cond: Union[QueryInstance, List[Any]]) ->QueryInstance:
+ def any(self, cond: Union[QueryInstance, List[Any]]) -> QueryInstance:
"""
Check if a condition is met by any document in a list,
where a condition can also be a sequence (e.g. list).
@@ -324,9 +418,20 @@ class Query(QueryInstance):
a list of which at least one document has to be contained
in the tested document.
"""
- pass
+ if callable(cond):
+ def test(value):
+ return is_sequence(value) and any(cond(e) for e in value)
+
+ else:
+ def test(value):
+ return is_sequence(value) and any(e in cond for e in value)
- def all(self, cond: Union['QueryInstance', List[Any]]) ->QueryInstance:
+ return self._generate_test(
+ lambda value: test(value),
+ ('any', self._path, freeze(cond))
+ )
+
+ def all(self, cond: Union['QueryInstance', List[Any]]) -> QueryInstance:
"""
Check if a condition is met by all documents in a list,
where a condition can also be a sequence (e.g. list).
@@ -346,9 +451,20 @@ class Query(QueryInstance):
:param cond: Either a query that all documents have to match or a list
which has to be contained in the tested document.
"""
- pass
+ if callable(cond):
+ def test(value):
+ return is_sequence(value) and all(cond(e) for e in value)
+
+ else:
+ def test(value):
+ return is_sequence(value) and all(e in value for e in cond)
+
+ return self._generate_test(
+ lambda value: test(value),
+ ('all', self._path, freeze(cond))
+ )
- def one_of(self, items: List[Any]) ->QueryInstance:
+ def one_of(self, items: List[Any]) -> QueryInstance:
"""
Check if the value is contained in a list or generator.
@@ -356,26 +472,55 @@ class Query(QueryInstance):
:param items: The list of items to check with
"""
- pass
+ return self._generate_test(
+ lambda value: value in items,
+ ('one_of', self._path, freeze(items))
+ )
- def noop(self) ->QueryInstance:
+ def fragment(self, document: Mapping) -> QueryInstance:
+ def test(value):
+ for key in document:
+ if key not in value or value[key] != document[key]:
+ return False
+
+ return True
+
+ return self._generate_test(
+ lambda value: test(value),
+ ('fragment', freeze(document)),
+ allow_empty_path=True
+ )
+
+ def noop(self) -> QueryInstance:
"""
Always evaluate to ``True``.
Useful for having a base value when composing queries dynamically.
"""
- pass
- def map(self, fn: Callable[[Any], Any]) ->'Query':
+ return QueryInstance(
+ lambda value: True,
+ ()
+ )
+
+ def map(self, fn: Callable[[Any], Any]) -> 'Query':
"""
Add a function to the query path. Similar to __getattr__ but for
arbitrary functions.
"""
- pass
+ query = type(self)()
+
+ # Now we add the callable to the query path ...
+ query._path = self._path + (fn,)
+ # ... and kill the hash - callable objects can be mutable, so it's
+ # harmful to cache their results.
+ query._hash = None
+
+ return query
-def where(key: str) ->Query:
+def where(key: str) -> Query:
"""
A shorthand for ``Query()[key]``
"""
- pass
+ return Query()[key]
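
To illustrate how the pieces above compose: attribute access on `Query` builds a path, the comparison operators go through `_generate_test()`, and `&`, `|`, `~` combine `QueryInstance` objects while keeping a hash value for the query cache. A small sketch:

```python
from tinydb import Query

User = Query()
q = (User.name == 'John') & (User.age >= 18)

q({'name': 'John', 'age': 30})     # True
q({'name': 'John', 'age': 10})     # False
(~q)({'name': 'Jane', 'age': 5})   # True

# any() accepts either a sub-query or a plain list of allowed values
User.groups.any(['admin'])({'groups': ['admin', 'dev']})   # True
```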
diff --git a/tinydb/storages.py b/tinydb/storages.py
index 0ddc223..d5a2db7 100644
--- a/tinydb/storages.py
+++ b/tinydb/storages.py
@@ -2,13 +2,15 @@
Contains the :class:`base class <tinydb.storages.Storage>` for storages and
implementations.
"""
+
import io
import json
import os
import warnings
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
-__all__ = 'Storage', 'JSONStorage', 'MemoryStorage'
+
+__all__ = ('Storage', 'JSONStorage', 'MemoryStorage')
def touch(path: str, create_dirs: bool):
@@ -18,7 +20,17 @@ def touch(path: str, create_dirs: bool):
:param path: The file to create.
:param create_dirs: Whether to create all missing parent directories.
"""
- pass
+ if create_dirs:
+ base_dir = os.path.dirname(path)
+
+ # Check if we need to create missing parent directories
+ if not os.path.exists(base_dir):
+ os.makedirs(base_dir)
+
+ # Create the file by opening it in 'a' mode which creates the file if it
+ # does not exist yet but does not modify its contents
+ with open(path, 'a'):
+ pass
class Storage(ABC):
@@ -29,8 +41,11 @@ class Storage(ABC):
some place (memory, file on disk, ...).
"""
+ # Using ABCMeta as metaclass allows instantiating only storages that have
+ # implemented read and write
+
@abstractmethod
- def read(self) ->Optional[Dict[str, Dict[str, Any]]]:
+ def read(self) -> Optional[Dict[str, Dict[str, Any]]]:
"""
Read the current state.
@@ -38,10 +53,11 @@ class Storage(ABC):
Return ``None`` here to indicate that the storage is empty.
"""
- pass
+
+ raise NotImplementedError('To be overridden!')
@abstractmethod
- def write(self, data: Dict[str, Dict[str, Any]]) ->None:
+ def write(self, data: Dict[str, Dict[str, Any]]) -> None:
"""
Write the current state of the database to the storage.
@@ -49,12 +65,14 @@ class Storage(ABC):
:param data: The current state of the database.
"""
- pass
- def close(self) ->None:
+ raise NotImplementedError('To be overridden!')
+
+ def close(self) -> None:
"""
Optional: Close open file handles, etc.
"""
+
pass
@@ -63,8 +81,7 @@ class JSONStorage(Storage):
Store the data in a JSON file.
"""
- def __init__(self, path: str, create_dirs=False, encoding=None,
- access_mode='r+', **kwargs):
+ def __init__(self, path: str, create_dirs=False, encoding=None, access_mode='r+', **kwargs):
"""
Create a new instance.
@@ -78,17 +95,67 @@ class JSONStorage(Storage):
:param access_mode: mode in which the file is opened (r, r+)
:type access_mode: str
"""
+
super().__init__()
+
self._mode = access_mode
self.kwargs = kwargs
+
if access_mode not in ('r', 'rb', 'r+', 'rb+'):
warnings.warn(
- "Using an `access_mode` other than 'r', 'rb', 'r+' or 'rb+' can cause data loss or corruption"
- )
- if any([(character in self._mode) for character in ('+', 'w', 'a')]):
+ 'Using an `access_mode` other than \'r\', \'rb\', \'r+\' '
+ 'or \'rb+\' can cause data loss or corruption'
+ )
+
+ # Create the file if it doesn't exist and creating is allowed by the
+ # access mode
+ if any([character in self._mode for character in ('+', 'w', 'a')]): # any of the writing modes
touch(path, create_dirs=create_dirs)
+
+ # Open the file for reading/writing
self._handle = open(path, mode=self._mode, encoding=encoding)
+ def close(self) -> None:
+ self._handle.close()
+
+ def read(self) -> Optional[Dict[str, Dict[str, Any]]]:
+ # Get the file size by moving the cursor to the file end and reading
+ # its location
+ self._handle.seek(0, os.SEEK_END)
+ size = self._handle.tell()
+
+ if not size:
+ # File is empty, so we return ``None`` so TinyDB can properly
+ # initialize the database
+ return None
+ else:
+ # Return the cursor to the beginning of the file
+ self._handle.seek(0)
+
+ # Load the JSON contents of the file
+ return json.load(self._handle)
+
+ def write(self, data: Dict[str, Dict[str, Any]]):
+ # Move the cursor to the beginning of the file just in case
+ self._handle.seek(0)
+
+ # Serialize the database state using the user-provided arguments
+ serialized = json.dumps(data, **self.kwargs)
+
+ # Write the serialized data to the file
+ try:
+ self._handle.write(serialized)
+ except io.UnsupportedOperation:
+ raise IOError('Cannot write to the database. Access mode is "{0}"'.format(self._mode))
+
+ # Ensure the file has been written
+ self._handle.flush()
+ os.fsync(self._handle.fileno())
+
+ # Remove data that is behind the new cursor in case the file has
+ # gotten shorter
+ self._handle.truncate()
+
class MemoryStorage(Storage):
"""
@@ -99,5 +166,12 @@ class MemoryStorage(Storage):
"""
Create a new instance.
"""
+
super().__init__()
self.memory = None
+
+ def read(self) -> Optional[Dict[str, Dict[str, Any]]]:
+ return self.memory
+
+ def write(self, data: Dict[str, Dict[str, Any]]):
+ self.memory = data
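
The storage contract implemented above is small: `read()` returns the whole database as a nested dict (or `None` while empty) and `write()` replaces it wholesale. A quick sketch (`/tmp/example_db.json` is a placeholder path):

```python
from tinydb.storages import JSONStorage, MemoryStorage

storage = MemoryStorage()
assert storage.read() is None      # an empty storage reports None
storage.write({'_default': {'1': {'key': 'value'}}})
assert storage.read() == {'_default': {'1': {'key': 'value'}}}

# JSONStorage persists the same structure to a file
json_storage = JSONStorage('/tmp/example_db.json', create_dirs=True)
json_storage.write({'_default': {}})
json_storage.close()
```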
diff --git a/tinydb/table.py b/tinydb/table.py
index 48eea63..60a8798 100644
--- a/tinydb/table.py
+++ b/tinydb/table.py
@@ -2,11 +2,25 @@
This module implements tables, the central place for accessing and manipulating
data in TinyDB.
"""
-from typing import Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Union, cast, Tuple
+
+from typing import (
+ Callable,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Mapping,
+ Optional,
+ Union,
+ cast,
+ Tuple
+)
+
from .queries import QueryLike
from .storages import Storage
from .utils import LRUCache
-__all__ = 'Document', 'Table'
+
+__all__ = ('Document', 'Table')
class Document(dict):
@@ -59,79 +73,215 @@ class Table:
:param name: The table name
:param cache_size: Maximum capacity of query cache
"""
+
+ #: The class used to represent documents
+ #:
+ #: .. versionadded:: 4.0
document_class = Document
+
+ #: The class used to represent a document ID
+ #:
+ #: .. versionadded:: 4.0
document_id_class = int
+
+ #: The class used for caching query results
+ #:
+ #: .. versionadded:: 4.0
query_cache_class = LRUCache
+
+ #: The default capacity of the query cache
+ #:
+ #: .. versionadded:: 4.0
default_query_cache_capacity = 10
- def __init__(self, storage: Storage, name: str, cache_size: int=
- default_query_cache_capacity):
+ def __init__(
+ self,
+ storage: Storage,
+ name: str,
+ cache_size: int = default_query_cache_capacity
+ ):
"""
Create a table instance.
"""
+
self._storage = storage
self._name = name
- self._query_cache: LRUCache[QueryLike, List[Document]
- ] = self.query_cache_class(capacity=cache_size)
+ self._query_cache: LRUCache[QueryLike, List[Document]] \
+ = self.query_cache_class(capacity=cache_size)
+
self._next_id = None
def __repr__(self):
- args = ['name={!r}'.format(self.name), 'total={}'.format(len(self)),
- 'storage={}'.format(self._storage)]
+ args = [
+ 'name={!r}'.format(self.name),
+ 'total={}'.format(len(self)),
+ 'storage={}'.format(self._storage),
+ ]
+
return '<{} {}>'.format(type(self).__name__, ', '.join(args))
@property
- def name(self) ->str:
+ def name(self) -> str:
"""
Get the table name.
"""
- pass
+ return self._name
@property
- def storage(self) ->Storage:
+ def storage(self) -> Storage:
"""
Get the table storage instance.
"""
- pass
+ return self._storage
- def insert(self, document: Mapping) ->int:
+ def insert(self, document: Mapping) -> int:
"""
Insert a new document into the table.
:param document: the document to insert
:returns: the inserted document's ID
"""
- pass
- def insert_multiple(self, documents: Iterable[Mapping]) ->List[int]:
+ # Make sure the document implements the ``Mapping`` interface
+ if not isinstance(document, Mapping):
+ raise ValueError('Document is not a Mapping')
+
+ # First, we get the document ID for the new document
+ if isinstance(document, Document):
+ # For a `Document` object we use the specified ID
+ doc_id = document.doc_id
+
+ # We also reset the stored next ID so the next insert won't
+ # re-use document IDs by accident when storing an old value
+ self._next_id = None
+ else:
+ # In all other cases we use the next free ID
+ doc_id = self._get_next_id()
+
+ # Now, we update the table and add the document
+ def updater(table: dict):
+ if doc_id in table:
+ raise ValueError(f'Document with ID {str(doc_id)} '
+ f'already exists')
+
+ # By calling ``dict(document)`` we convert the data we got to a
+ # ``dict`` instance even if it was a different class that
+ # implemented the ``Mapping`` interface
+ table[doc_id] = dict(document)
+
+ # See below for details on ``Table._update``
+ self._update_table(updater)
+
+ return doc_id
+
+ def insert_multiple(self, documents: Iterable[Mapping]) -> List[int]:
"""
Insert multiple documents into the table.
:param documents: an Iterable of documents to insert
:returns: a list containing the inserted documents' IDs
"""
- pass
+ doc_ids = []
+
+ def updater(table: dict):
+ for document in documents:
+
+ # Make sure the document implements the ``Mapping`` interface
+ if not isinstance(document, Mapping):
+ raise ValueError('Document is not a Mapping')
+
+ if isinstance(document, Document):
+ # Check if document does not override an existing document
+ if document.doc_id in table:
+ raise ValueError(
+ f'Document with ID {str(document.doc_id)} '
+ f'already exists'
+ )
+
+ # Store the doc_id, so we can return all document IDs
+ # later. Then save the document with its doc_id and
+ # skip the rest of the current loop
+ doc_id = document.doc_id
+ doc_ids.append(doc_id)
+ table[doc_id] = dict(document)
+ continue
+
+ # Generate new document ID for this document
+ # Store the doc_id, so we can return all document IDs
+ # later, then save the document with the new doc_id
+ doc_id = self._get_next_id()
+ doc_ids.append(doc_id)
+ table[doc_id] = dict(document)
- def all(self) ->List[Document]:
+ # See below for details on ``Table._update``
+ self._update_table(updater)
+
+ return doc_ids
+
+ def all(self) -> List[Document]:
"""
Get all documents stored in the table.
:returns: a list with all documents.
"""
- pass
- def search(self, cond: QueryLike) ->List[Document]:
+ # iter(self) (implemented in Table.__iter__) provides an iterator
+ # that returns all documents in this table. We use it to get a list
+ # of all documents by using the ``list`` constructor to perform the
+ # conversion.
+
+ return list(iter(self))
+
+ def search(self, cond: QueryLike) -> List[Document]:
"""
Search for all documents matching a 'where' cond.
:param cond: the condition to check against
:returns: list of matching documents
"""
- pass
- def get(self, cond: Optional[QueryLike]=None, doc_id: Optional[int]=
- None, doc_ids: Optional[List]=None) ->Optional[Union[Document, List
- [Document]]]:
+ # First, we check the query cache to see if it has results for this
+ # query
+ cached_results = self._query_cache.get(cond)
+ if cached_results is not None:
+ return cached_results[:]
+
+ # Perform the search by applying the query to all documents.
+ # Then, only if the document matches the query, convert it
+ # to the document class and document ID class.
+ docs = [
+ self.document_class(doc, self.document_id_class(doc_id))
+ for doc_id, doc in self._read_table().items()
+ if cond(doc)
+ ]
+
+ # Only cache cacheable queries.
+ #
+ # This weird `getattr` dance is needed to make MyPy happy as
+ # it doesn't know that a query might have a `is_cacheable` method
+ # that is not declared in the `QueryLike` protocol due to it being
+ # optional.
+ # See: https://github.com/python/mypy/issues/1424
+ #
+ # Note also that by default we expect custom query objects to be
+ # cacheable (which means they need to have a stable hash value).
+ # This is to keep consistency with TinyDB's behavior before
+ # `is_cacheable` was introduced which assumed that all queries
+ # are cacheable.
+ is_cacheable: Callable[[], bool] = getattr(cond, 'is_cacheable',
+ lambda: True)
+ if is_cacheable():
+ # Update the query cache
+ self._query_cache[cond] = docs[:]
+
+ return docs
+
+ def get(
+ self,
+ cond: Optional[QueryLike] = None,
+ doc_id: Optional[int] = None,
+ doc_ids: Optional[List] = None
+ ) -> Optional[Union[Document, List[Document]]]:
"""
Get exactly one document specified by a query or a document ID.
However, if multiple document IDs are given then returns all
@@ -145,10 +295,55 @@ class Table:
:returns: the document(s) or ``None``
"""
- pass
+ table = self._read_table()
+
+ if doc_id is not None:
+ # Retrieve a document specified by its ID
+ raw_doc = table.get(str(doc_id), None)
+
+ if raw_doc is None:
+ return None
- def contains(self, cond: Optional[QueryLike]=None, doc_id: Optional[int
- ]=None) ->bool:
+ # Convert the raw data to the document class
+ return self.document_class(raw_doc, doc_id)
+
+ elif doc_ids is not None:
+ # Filter the table by extracting out all those documents which
+ # have doc id specified in the doc_id list.
+
+ # Since document IDs will be unique, we make it a set to ensure
+ # constant time lookup
+ doc_ids_set = set(str(doc_id) for doc_id in doc_ids)
+
+ # Now return the filtered documents in form of list
+ return [
+ self.document_class(doc, self.document_id_class(doc_id))
+ for doc_id, doc in table.items()
+ if doc_id in doc_ids_set
+ ]
+
+ elif cond is not None:
+ # Find a document specified by a query
+ # The trailing underscore in doc_id_ is needed so MyPy
+ # doesn't think that `doc_id_` (which is a string) needs
+ # to have the same type as `doc_id` which is this function's
+ # parameter and is an optional `int`.
+ for doc_id_, doc in self._read_table().items():
+ if cond(doc):
+ return self.document_class(
+ doc,
+ self.document_id_class(doc_id_)
+ )
+
+ return None
+
+ raise RuntimeError('You have to pass either cond or doc_id or doc_ids')
+
+ def contains(
+ self,
+ cond: Optional[QueryLike] = None,
+ doc_id: Optional[int] = None
+ ) -> bool:
"""
Check whether the database contains a document matching a query or
an ID.
@@ -158,11 +353,22 @@ class Table:
:param cond: the condition use
:param doc_id: the document ID to look for
"""
- pass
+ if doc_id is not None:
+ # Documents specified by ID
+ return self.get(doc_id=doc_id) is not None
- def update(self, fields: Union[Mapping, Callable[[Mapping], None]],
- cond: Optional[QueryLike]=None, doc_ids: Optional[Iterable[int]]=None
- ) ->List[int]:
+ elif cond is not None:
+ # Document specified by condition
+ return self.get(cond) is not None
+
+ raise RuntimeError('You have to pass either cond or doc_id')
+
+ def update(
+ self,
+ fields: Union[Mapping, Callable[[Mapping], None]],
+ cond: Optional[QueryLike] = None,
+ doc_ids: Optional[Iterable[int]] = None,
+ ) -> List[int]:
"""
Update all matching documents to have a given set of fields.
@@ -172,19 +378,135 @@ class Table:
:param doc_ids: a list of document IDs
:returns: a list containing the updated document's ID
"""
- pass
- def update_multiple(self, updates: Iterable[Tuple[Union[Mapping,
- Callable[[Mapping], None]], QueryLike]]) ->List[int]:
+ # Define the function that will perform the update
+ if callable(fields):
+ def perform_update(table, doc_id):
+ # Update documents by calling the update function provided by
+ # the user
+ fields(table[doc_id])
+ else:
+ def perform_update(table, doc_id):
+ # Update documents by setting all fields from the provided data
+ table[doc_id].update(fields)
+
+ if doc_ids is not None:
+ # Perform the update operation for documents specified by a list
+ # of document IDs
+
+ updated_ids = list(doc_ids)
+
+ def updater(table: dict):
+ # Call the processing callback with all document IDs
+ for doc_id in updated_ids:
+ perform_update(table, doc_id)
+
+ # Perform the update operation (see _update_table for details)
+ self._update_table(updater)
+
+ return updated_ids
+
+ elif cond is not None:
+ # Perform the update operation for documents specified by a query
+
+ # Collect affected doc_ids
+ updated_ids = []
+
+ def updater(table: dict):
+ _cond = cast(QueryLike, cond)
+
+ # We need to convert the keys iterator to a list because
+ # we may remove entries from the ``table`` dict during
+ # iteration and doing this without the list conversion would
+ # result in an exception (RuntimeError: dictionary changed size
+ # during iteration)
+ for doc_id in list(table.keys()):
+ # Pass through all documents to find documents matching the
+ # query. Call the processing callback with the document ID
+ if _cond(table[doc_id]):
+ # Add ID to list of updated documents
+ updated_ids.append(doc_id)
+
+ # Perform the update (see above)
+ perform_update(table, doc_id)
+
+ # Perform the update operation (see _update_table for details)
+ self._update_table(updater)
+
+ return updated_ids
+
+ else:
+ # Update all documents unconditionally
+
+ updated_ids = []
+
+ def updater(table: dict):
+ # Process all documents
+ for doc_id in list(table.keys()):
+ # Add ID to list of updated documents
+ updated_ids.append(doc_id)
+
+ # Perform the update (see above)
+ perform_update(table, doc_id)
+
+ # Perform the update operation (see _update_table for details)
+ self._update_table(updater)
+
+ return updated_ids
+
+ def update_multiple(
+ self,
+ updates: Iterable[
+ Tuple[Union[Mapping, Callable[[Mapping], None]], QueryLike]
+ ],
+ ) -> List[int]:
"""
Update all matching documents to have a given set of fields.
:returns: a list containing the updated document's ID
"""
- pass
- def upsert(self, document: Mapping, cond: Optional[QueryLike]=None) ->List[
- int]:
+ # Define the function that will perform the update
+ def perform_update(fields, table, doc_id):
+ if callable(fields):
+ # Update documents by calling the update function provided
+ # by the user
+ fields(table[doc_id])
+ else:
+ # Update documents by setting all fields from the provided
+ # data
+ table[doc_id].update(fields)
+
+ # Perform the update operation for documents specified by a query
+
+ # Collect affected doc_ids
+ updated_ids = []
+
+ def updater(table: dict):
+ # We need to convert the keys iterator to a list because
+ # we may remove entries from the ``table`` dict during
+ # iteration and doing this without the list conversion would
+ # result in an exception (RuntimeError: dictionary changed size
+ # during iteration)
+ for doc_id in list(table.keys()):
+ for fields, cond in updates:
+ _cond = cast(QueryLike, cond)
+
+ # Pass through all documents to find documents matching the
+ # query. Call the processing callback with the document ID
+ if _cond(table[doc_id]):
+ # Add ID to list of updated documents
+ updated_ids.append(doc_id)
+
+ # Perform the update (see above)
+ perform_update(fields, table, doc_id)
+
+ # Perform the update operation (see _update_table for details)
+ self._update_table(updater)
+
+ return updated_ids
+
+ def upsert(self, document: Mapping, cond: Optional[QueryLike] = None) -> List[int]:
"""
Update documents, if they exist, insert them otherwise.
@@ -197,10 +519,39 @@ class Table:
Document with a doc_id
:returns: a list containing the updated documents' IDs
"""
- pass
- def remove(self, cond: Optional[QueryLike]=None, doc_ids: Optional[
- Iterable[int]]=None) ->List[int]:
+ # Extract doc_id
+ if isinstance(document, Document) and hasattr(document, 'doc_id'):
+ doc_ids: Optional[List[int]] = [document.doc_id]
+ else:
+ doc_ids = None
+
+ # Make sure we can actually find a matching document
+ if doc_ids is None and cond is None:
+ raise ValueError("If you don't specify a search query, you must "
+ "specify a doc_id. Hint: use a table.Document "
+ "object.")
+
+ # Perform the update operation
+ try:
+ updated_docs: Optional[List[int]] = self.update(document, cond, doc_ids)
+ except KeyError:
+ # This happens when a doc_id is specified, but it's missing
+ updated_docs = None
+
+ # If documents have been updated: return their IDs
+ if updated_docs:
+ return updated_docs
+
+ # There are no documents that match the specified query -> insert the
+ # data as a new document
+ return [self.insert(document)]
+
+ def remove(
+ self,
+ cond: Optional[QueryLike] = None,
+ doc_ids: Optional[Iterable[int]] = None,
+ ) -> List[int]:
"""
Remove all matching documents.
@@ -208,50 +559,139 @@ class Table:
:param doc_ids: a list of document IDs
:returns: a list containing the removed documents' ID
"""
- pass
+ if doc_ids is not None:
+ # This function returns the list of IDs for the documents that have
+ # been removed. When removing documents identified by a set of
+ # document IDs, it's this list of document IDs we need to return
+ # later.
+ # We convert the document ID iterator into a list, so we can both
+ # use the document IDs to remove the specified documents and
+ # to return the list of affected document IDs
+ removed_ids = list(doc_ids)
+
+ def updater(table: dict):
+ for doc_id in removed_ids:
+ table.pop(doc_id)
+
+ # Perform the remove operation
+ self._update_table(updater)
+
+ return removed_ids
+
+ if cond is not None:
+ removed_ids = []
+
+ # This updater function will be called with the table data
+ # as its first argument. See ``Table._update`` for details on this
+ # operation
+ def updater(table: dict):
+ # We need to convince MyPy (the static type checker) that
+ # the ``cond is not None`` invariant still holds true when
+ # the updater function is called
+ _cond = cast(QueryLike, cond)
+
+ # We need to convert the keys iterator to a list because
+ # we may remove entries from the ``table`` dict during
+ # iteration and doing this without the list conversion would
+ # result in an exception (RuntimeError: dictionary changed size
+ # during iteration)
+ for doc_id in list(table.keys()):
+ if _cond(table[doc_id]):
+ # Add document ID to list of removed document IDs
+ removed_ids.append(doc_id)
- def truncate(self) ->None:
+ # Remove document from the table
+ table.pop(doc_id)
+
+ # Perform the remove operation
+ self._update_table(updater)
+
+ return removed_ids
+
+ raise RuntimeError('Use truncate() to remove all documents')
+
+ def truncate(self) -> None:
"""
Truncate the table by removing all documents.
"""
- pass
- def count(self, cond: QueryLike) ->int:
+ # Update the table by resetting all data
+ self._update_table(lambda table: table.clear())
+
+ # Reset document ID counter
+ self._next_id = None
+
+ def count(self, cond: QueryLike) -> int:
"""
Count the documents matching a query.
:param cond: the condition use
"""
- pass
- def clear_cache(self) ->None:
+ return len(self.search(cond))
+
+ def clear_cache(self) -> None:
"""
Clear the query cache.
"""
- pass
+
+ self._query_cache.clear()
def __len__(self):
"""
Count the total number of documents in this table.
"""
+
return len(self._read_table())
- def __iter__(self) ->Iterator[Document]:
+ def __iter__(self) -> Iterator[Document]:
"""
Iterate over all documents stored in the table.
:returns: an iterator over all documents.
"""
+
+ # Iterate all documents and their IDs
for doc_id, doc in self._read_table().items():
+ # Convert documents to the document class
yield self.document_class(doc, self.document_id_class(doc_id))
def _get_next_id(self):
"""
Return the ID for a newly inserted document.
"""
- pass
- def _read_table(self) ->Dict[str, Mapping]:
+ # If we already know the next ID
+ if self._next_id is not None:
+ next_id = self._next_id
+ self._next_id = next_id + 1
+
+ return next_id
+
+ # Determine the next document ID by finding out the max ID value
+ # of the current table documents
+
+ # Read the table documents
+ table = self._read_table()
+
+ # If the table is empty, set the initial ID
+ if not table:
+ next_id = 1
+ self._next_id = next_id + 1
+
+ return next_id
+
+ # Determine the next ID based on the maximum ID that's currently in use
+ max_id = max(self.document_id_class(i) for i in table.keys())
+ next_id = max_id + 1
+
+ # The next ID we will return AFTER this call needs to be larger than
+ # the current next ID we calculated
+ self._next_id = next_id + 1
+
+ return next_id
+
+ def _read_table(self) -> Dict[str, Mapping]:
"""
Read the table data from the underlying storage.
@@ -259,7 +699,22 @@ class Table:
we may not want to convert *all* documents when returning
only one document for example.
"""
- pass
+
+ # Retrieve the tables from the storage
+ tables = self._storage.read()
+
+ if tables is None:
+ # The database is empty
+ return {}
+
+ # Retrieve the current table's data
+ try:
+ table = tables[self.name]
+ except KeyError:
+ # The table does not exist yet, so it is empty
+ return {}
+
+ return table
def _update_table(self, updater: Callable[[Dict[int, Mapping]], None]):
"""
@@ -274,4 +729,41 @@ class Table:
As a further optimization, we don't convert the documents into the
document class, as the table data will *not* be returned to the user.
"""
- pass
+
+ tables = self._storage.read()
+
+ if tables is None:
+ # The database is empty
+ tables = {}
+
+ try:
+ raw_table = tables[self.name]
+ except KeyError:
+ # The table does not exist yet, so it is empty
+ raw_table = {}
+
+ # Convert the document IDs to the document ID class.
+ # This is required as the rest of TinyDB expects the document IDs
+ # to be an instance of ``self.document_id_class`` but the storage
+ # might convert dict keys to strings.
+ table = {
+ self.document_id_class(doc_id): doc
+ for doc_id, doc in raw_table.items()
+ }
+
+ # Perform the table update operation
+ updater(table)
+
+ # Convert the document IDs back to strings.
+ # This is required as some storages (most notably the JSON file format)
+ # don't support IDs other than strings.
+ tables[self.name] = {
+ str(doc_id): doc
+ for doc_id, doc in table.items()
+ }
+
+ # Write the newly updated data back to the storage
+ self._storage.write(tables)
+
+ # Clear the query cache, as the table contents have changed
+ self.clear_cache()
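
The `Table` methods filled in above cover the full CRUD surface. A minimal sketch against the in-memory storage, tying the pieces together:

```python
from tinydb import TinyDB, where
from tinydb.storages import MemoryStorage

db = TinyDB(storage=MemoryStorage)
table = db.table('users')

doc_id = table.insert({'name': 'John', 'age': 22})
table.insert_multiple([{'name': 'Jane'}, {'name': 'Max'}])

table.search(where('age') >= 18)                    # repeated calls hit the LRU query cache
table.get(doc_id=doc_id)                            # direct lookup by document ID
table.update({'age': 23}, where('name') == 'John')
table.upsert({'name': 'Eve', 'age': 30}, where('name') == 'Eve')   # no match -> insert
table.remove(where('name') == 'Max')
len(table)                                          # 3 documents remain: John, Jane, Eve
```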
diff --git a/tinydb/utils.py b/tinydb/utils.py
index 0721622..08430ba 100644
--- a/tinydb/utils.py
+++ b/tinydb/utils.py
@@ -1,13 +1,17 @@
"""
Utility functions.
"""
+
from collections import OrderedDict, abc
-from typing import List, Iterator, TypeVar, Generic, Union, Optional, Type, TYPE_CHECKING
+from typing import List, Iterator, TypeVar, Generic, Union, Optional, Type, \
+ TYPE_CHECKING
+
K = TypeVar('K')
V = TypeVar('V')
D = TypeVar('D')
T = TypeVar('T')
-__all__ = 'LRUCache', 'freeze', 'with_typehint'
+
+__all__ = ('LRUCache', 'freeze', 'with_typehint')
def with_typehint(baseclass: Type[T]):
@@ -23,7 +27,13 @@ def with_typehint(baseclass: Type[T]):
MyPy does not. For that reason TinyDB has a MyPy plugin in
``mypy_plugin.py`` that adds support for this pattern.
"""
- pass
+ if TYPE_CHECKING:
+ # In the case of type checking: pretend that the target class inherits
+ # from the specified base class
+ return baseclass
+
+ # Otherwise: just inherit from `object` like a regular Python class
+ return object
class LRUCache(abc.MutableMapping, Generic[K, V]):
@@ -40,31 +50,66 @@ class LRUCache(abc.MutableMapping, Generic[K, V]):
be discarded.
"""
- def __init__(self, capacity=None) ->None:
+ def __init__(self, capacity=None) -> None:
self.capacity = capacity
self.cache: OrderedDict[K, V] = OrderedDict()
- def __len__(self) ->int:
+ @property
+ def lru(self) -> List[K]:
+ return list(self.cache.keys())
+
+ @property
+ def length(self) -> int:
+ return len(self.cache)
+
+ def clear(self) -> None:
+ self.cache.clear()
+
+ def __len__(self) -> int:
return self.length
- def __contains__(self, key: object) ->bool:
+ def __contains__(self, key: object) -> bool:
return key in self.cache
- def __setitem__(self, key: K, value: V) ->None:
+ def __setitem__(self, key: K, value: V) -> None:
self.set(key, value)
- def __delitem__(self, key: K) ->None:
+ def __delitem__(self, key: K) -> None:
del self.cache[key]
- def __getitem__(self, key) ->V:
+ def __getitem__(self, key) -> V:
value = self.get(key)
if value is None:
raise KeyError(key)
+
return value
- def __iter__(self) ->Iterator[K]:
+ def __iter__(self) -> Iterator[K]:
return iter(self.cache)
+ def get(self, key: K, default: Optional[D] = None) -> Optional[Union[V, D]]:
+ value = self.cache.get(key)
+
+ if value is not None:
+ self.cache.move_to_end(key, last=True)
+
+ return value
+
+ return default
+
+ def set(self, key: K, value: V):
+ if self.cache.get(key):
+ self.cache.move_to_end(key, last=True)
+
+ else:
+ self.cache[key] = value
+
+ # Check if the cache is full and we have to remove old items.
+ # If the cache has no capacity limit (``capacity is None``), the check
+ # below never triggers and the cache is never cleared.
+ if self.capacity is not None and self.length > self.capacity:
+ self.cache.popitem(last=False)
+
class FrozenDict(dict):
"""
@@ -76,16 +121,39 @@ class FrozenDict(dict):
"""
def __hash__(self):
+ # Calculate the hash by hashing a tuple of all dict items
return hash(tuple(sorted(self.items())))
+
+ def _immutable(self, *args, **kws):
+ raise TypeError('object is immutable')
+
+ # Disable write access to the dict
__setitem__ = _immutable
__delitem__ = _immutable
clear = _immutable
- setdefault = _immutable
+ setdefault = _immutable # type: ignore
popitem = _immutable
+ def update(self, e=None, **f):
+ raise TypeError('object is immutable')
+
+ def pop(self, k, d=None):
+ raise TypeError('object is immutable')
+
def freeze(obj):
"""
Freeze an object by making it immutable and thus hashable.
"""
- pass
+ if isinstance(obj, dict):
+ # Transform dicts into ``FrozenDict``s
+ return FrozenDict((k, freeze(v)) for k, v in obj.items())
+ elif isinstance(obj, list):
+ # Transform lists into tuples
+ return tuple(freeze(el) for el in obj)
+ elif isinstance(obj, set):
+ # Transform sets into ``frozenset``s
+ return frozenset(obj)
+ else:
+ # Don't handle all other objects
+ return obj
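
Finally, a small sketch of the utilities above: `LRUCache` evicts the least recently used entry once its capacity is exceeded, and `freeze()` turns nested dicts, lists, and sets into hashable equivalents, which is what keeps query hash values stable:

```python
from tinydb.utils import LRUCache, freeze

cache: LRUCache[str, int] = LRUCache(capacity=2)
cache['a'] = 1
cache['b'] = 2
cache['c'] = 3                       # evicts 'a', the least recently used key
assert 'a' not in cache and 'b' in cache

key = freeze({'tags': ['x', 'y'], 'meta': {'n': 1}})
assert hash(key) == hash(freeze({'tags': ['x', 'y'], 'meta': {'n': 1}}))
```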