diff --git a/tinydb/database.py b/tinydb/database.py
index a9b6c89..149806b 100644
--- a/tinydb/database.py
+++ b/tinydb/database.py
@@ -94,7 +94,13 @@ class TinyDB(TableBase):
         :param name: The name of the table.
         :param kwargs: Keyword arguments to pass to the table class constructor
         """
-        pass
+        if name in self._tables:
+            return self._tables[name]
+
+        # Table's constructor takes the storage first, then the name.
+        table = self.table_class(self._storage, name, **kwargs)
+        self._tables[name] = table
+        return table
 
     def tables(self) -> Set[str]:
         """
@@ -102,13 +108,16 @@ class TinyDB(TableBase):
 
         :returns: a set of table names
         """
-        pass
+        # Only tables that have data in the storage are reported; read()
+        # returns None while nothing has been stored yet.
+        return set(self._storage.read() or {})
 
     def drop_tables(self) -> None:
         """
         Drop all tables from the database. **CANNOT BE REVERSED!**
         """
-        pass
+        self._tables.clear()
+        self._storage.write({})
 
     def drop_table(self, name: str) -> None:
         """
@@ -116,7 +125,15 @@ class TinyDB(TableBase):
 
         :param name: The name of the table to drop.
         """
-        pass
+        self._tables.pop(name, None)
+
+        data = self._storage.read()
+        if not data or name not in data:
+            # Empty storage (read() may return None) or table never stored.
+            return
+
+        del data[name]
+        self._storage.write(data)
 
     @property
     def storage(self) -> Storage:
@@ -126,7 +143,7 @@ class TinyDB(TableBase):
         :return: This instance's storage
         :rtype: Storage
         """
-        pass
+        return self._storage
 
     def close(self) -> None:
         """
@@ -143,7 +160,10 @@ class TinyDB(TableBase):
 
         Upon leaving this context, the ``close`` method will be called.
         """
-        pass
+        # Clear the flag first so it is accurate even if the storage
+        # raises while closing.
+        self._opened = False
+        self._storage.close()
 
     def __enter__(self):
         """
@@ -184,4 +204,4 @@ class TinyDB(TableBase):
         """
         Return an iterator for the default table's documents.
         """
-        return iter(self.table(self.default_table_name))
\ No newline at end of file
+        return iter(self.table(self.default_table_name))
diff --git a/tinydb/middlewares.py b/tinydb/middlewares.py
index ba9ac98..9a93123 100644
--- a/tinydb/middlewares.py
+++ b/tinydb/middlewares.py
@@ -86,4 +86,8 @@ class CachingMiddleware(Middleware):
         """
         Flush all unwritten data to disk.
         """
-        pass
\ No newline at end of file
+        # Skip the write when nothing changed since the last flush, so an
+        # untouched cache never rewrites the underlying storage.
+        if self._cache_modified_count > 0:
+            self.storage.write(self.cache)
+            self._cache_modified_count = 0
diff --git a/tinydb/operations.py b/tinydb/operations.py
index dcf2ff7..3c5d89f 100644
--- a/tinydb/operations.py
+++ b/tinydb/operations.py
@@ -18,28 +18,39 @@ def add(field, n):
     """
     Add ``n`` to a given field in the document.
     """
-    pass
+    def transform(doc):
+        # Like ``doc[field] += n``: a missing field raises KeyError instead
+        # of being silently skipped (and still reported as updated).
+        doc[field] += n
+
+    return transform
 
 def subtract(field, n):
     """
-    Subtract ``n`` to a given field in the document.
+    Subtract ``n`` from a given field in the document.
     """
-    pass
+    def transform(doc):
+        doc[field] -= n  # a missing field raises KeyError, as in add()
+
+    return transform
 
 def set(field, val):
     """
     Set a given field to ``val``.
     """
-    pass
+    def transform(doc):
+        doc[field] = val
+
+    return transform
 
 def increment(field):
     """
     Increment a given field in the document by 1.
     """
-    pass
+    return add(field, 1)
 
 def decrement(field):
     """
     Decrement a given field in the document by 1.
     """
-    pass
\ No newline at end of file
+    return subtract(field, 1)
diff --git a/tinydb/queries.py b/tinydb/queries.py
index 78e7e99..cabafb2 100644
--- a/tinydb/queries.py
+++ b/tinydb/queries.py
@@ -174,7 +174,30 @@ class Query(QueryInstance):
         :param hashval: The hash of the query.
         :return: A :class:`~tinydb.queries.QueryInstance` object
         """
-        pass
+        if not self._path and not allow_empty_path:
+            raise RuntimeError('Empty query was evaluated')
+
+        def runner(value):
+            # Resolving the path can fail on a missing key, a value that is
+            # not subscriptable or a failing ``map`` function: "no match".
+            try:
+                for part in self._path:
+                    if isinstance(part, str):
+                        value = value[part]
+                    else:
+                        value = part(value)
+            except (KeyError, TypeError, ValueError):
+                return False
+
+            # Errors raised by the test itself propagate, so bugs in
+            # user-supplied test functions are not silently turned into
+            # non-matches.
+            return test(value)
+
+        # A query derived from a non-cacheable one (e.g. after ``map``) must
+        # not be cacheable either.
+        return QueryInstance(runner,
+                             hashval if self.is_cacheable() else None)
 
     def __eq__(self, rhs: Any):
         """
@@ -287,7 +310,8 @@ class Query(QueryInstance):
                      argument
         :param args: Additional arguments to pass to the test function
         """
-        pass
+        return self._generate_test(lambda value: func(value, *args),
+                                   ('test', self._path, func, args))
 
     def any(self, cond: Union[QueryInstance, List[Any]]) -> QueryInstance:
         """
@@ -311,7 +335,17 @@ class Query(QueryInstance):
                      a list of which at least one document has to be contained
                      in the tested document.
         """
-        pass
+        # A field value that is not iterable (e.g. a number) never matches.
+        if isinstance(cond, QueryInstance):
+            def test(value):
+                return (hasattr(value, '__iter__')
+                        and any(cond(item) for item in value))
+        else:
+            def test(value):
+                return (hasattr(value, '__iter__')
+                        and any(item in cond for item in value))
+
+        return self._generate_test(test, ('any', self._path, freeze(cond)))
 
     def all(self, cond: Union['QueryInstance', List[Any]]) -> QueryInstance:
         """
@@ -333,6 +367,16 @@ class Query(QueryInstance):
         :param cond: Either a query that all documents have to match or a list
                      which has to be contained in the tested document.
         """
-        pass
+        # A field value that is not iterable (e.g. a number) never matches.
+        if isinstance(cond, QueryInstance):
+            def test(value):
+                return (hasattr(value, '__iter__')
+                        and all(cond(item) for item in value))
+        else:
+            def test(value):
+                return (hasattr(value, '__iter__')
+                        and all(item in value for item in cond))
+
+        return self._generate_test(test, ('all', self._path, freeze(cond)))
 
     def one_of(self, items: List[Any]) -> QueryInstance:
@@ -343,7 +387,8 @@ class Query(QueryInstance):
 
         :param items: The list of items to check with
         """
-        pass
+        return self._generate_test(lambda value: value in items,
+                                   ('one_of', self._path, freeze(items)))
 
     def noop(self) -> QueryInstance:
         """
@@ -351,17 +396,24 @@ class Query(QueryInstance):
         Useful for having a base value when composing queries dynamically.
         """
-        pass
+        # Built directly instead of via _generate_test, which would reject
+        # the empty path of a bare ``Query()``.
+        return QueryInstance(lambda value: True, ())
 
     def map(self, fn: Callable[[Any], Any]) -> 'Query':
         """
         Add a function to the query path. Similar to __getattr__ but for
         arbitrary functions.
         """
-        pass
+        query = type(self)()
+        query._path = self._path + (fn,)
+        # Callables may be stateful, so a query that uses one is marked
+        # non-cacheable.
+        query._hash = None
+        return query
 
 
 def where(key: str) -> Query:
     """
     A shorthand for ``Query()[key]``
     """
-    pass
\ No newline at end of file
+    return Query()[key]
diff --git a/tinydb/storages.py b/tinydb/storages.py
index 86c0987..6c1fec9 100644
--- a/tinydb/storages.py
+++ b/tinydb/storages.py
@@ -92,5 +92,17 @@ class MemoryStorage(Storage):
         """
         Create a new instance.
         """
-        super().__init__()
-        self.memory = None
\ No newline at end of file
+        super().__init__()
+        self.memory = None
+
+    def read(self) -> Optional[Dict[str, Dict[str, Any]]]:
+        """
+        Return the stored data, or ``None`` if nothing was written yet.
+        """
+        return self.memory
+
+    def write(self, data: Dict[str, Dict[str, Any]]) -> None:
+        """
+        Replace the stored data with ``data`` (kept by reference, not copied).
+        """
+        self.memory = data
diff --git a/tinydb/table.py b/tinydb/table.py
index 5f0a160..6fea6fb 100644
--- a/tinydb/table.py
+++ b/tinydb/table.py
@@ -80,14 +80,14 @@ class Table:
         """
         Get the table name.
         """
-        pass
+        return self._name
 
     @property
     def storage(self) -> Storage:
         """
         Get the table storage instance.
         """
-        pass
+        return self._storage
 
     def insert(self, document: Mapping) -> int:
         """
@@ -96,7 +96,27 @@ class Table:
         :param document: the document to insert
         :returns: the inserted document's ID
         """
-        pass
+        if not isinstance(document, Mapping):
+            raise ValueError('Document is not a Mapping')
+
+        if isinstance(document, self.document_class):
+            # Keep the caller's ID and let _get_next_id() recompute from the
+            # stored table, so later inserts never reuse it.
+            doc_id = document.doc_id
+            self._next_id = None
+        else:
+            doc_id = self._get_next_id()
+
+        def updater(table: dict):
+            if doc_id in table:
+                raise ValueError(
+                    'Document with ID {} already exists'.format(doc_id))
+            # Store a plain-dict copy so later changes the caller makes to
+            # ``document`` do not leak into the stored data.
+            table[doc_id] = dict(document)
+
+        self._update_table(updater)
+        return doc_id
 
     def insert_multiple(self, documents: Iterable[Mapping]) -> List[int]:
         """
@@ -105,7 +125,9 @@ class Table:
         :param documents: an Iterable of documents to insert
         :returns: a list containing the inserted documents' IDs
         """
-        pass
+        # Documents go through insert() one by one, so if one is rejected
+        # the documents before it stay stored.
+        return [self.insert(document) for document in documents]
 
     def all(self) -> List[Document]:
         """
@@ -113,7 +135,11 @@ class Table:
 
         :returns: a list with all documents.
         """
-        pass
+        # Stored IDs are strings (see _read_table); convert them to ints.
+        return [
+            self.document_class(doc, int(doc_id))
+            for doc_id, doc in self._read_table().items()
+        ]
 
     def search(self, cond: QueryLike) -> List[Document]:
         """
@@ -122,7 +148,22 @@ class Table:
         :param cond: the condition to check against
         :returns: list of matching documents
         """
-        pass
+        if cond in self._query_cache:
+            # Return a copy so callers cannot mutate the cached list.
+            return self._query_cache[cond][:]
+
+        docs = [
+            self.document_class(doc, int(doc_id))
+            for doc_id, doc in self._read_table().items()
+            if cond(doc)
+        ]
+
+        # Queries whose is_cacheable() returns False are not stored;
+        # objects without is_cacheable() are cached.
+        is_cacheable = getattr(cond, 'is_cacheable', lambda: True)
+        if is_cacheable():
+            self._query_cache[cond] = docs[:]
+        return docs
 
     def get(self, cond: Optional[QueryLike]=None, doc_id: Optional[int]=None, doc_ids: Optional[List]=None) -> Optional[Union[Document, List[Document]]]:
         """
@@ -138,7 +179,27 @@ class Table:
 
         :returns: the document(s) or ``None``
         """
-        pass
+        if doc_id is not None:
+            # Stored IDs are strings, so look the ID up in that form.
+            doc = self._read_table().get(str(doc_id))
+            if doc is None:
+                return None
+            return self.document_class(doc, doc_id)
+
+        if doc_ids is not None:
+            # Read the table once for all requested IDs.
+            table = self._read_table()
+            return [
+                self.document_class(table[str(id_)], id_)
+                for id_ in doc_ids
+                if str(id_) in table
+            ]
+
+        if cond is not None:
+            matches = self.search(cond)
+            return matches[0] if matches else None
+
+        return None
 
     def contains(self, cond: Optional[QueryLike]=None, doc_id: Optional[int]=None) -> bool:
         """
@@ -150,7 +211,13 @@ class Table:
         :param cond: the condition use
         :param doc_id: the document ID to look for
         """
-        pass
+        if doc_id is not None:
+            return str(doc_id) in self._read_table()
+
+        if cond is not None:
+            return self.get(cond) is not None
+
+        return False
 
     def update(self, fields: Union[Mapping, Callable[[Mapping], None]], cond: Optional[QueryLike]=None, doc_ids: Optional[Iterable[int]]=None) -> List[int]:
         """
@@ -162,7 +229,31 @@ class Table:
         :param doc_ids: a list of document IDs
         :returns: a list containing the updated document's ID
         """
-        pass
+        if callable(fields):
+            def perform_update(table: dict, doc_id: int):
+                fields(table[doc_id])
+        else:
+            def perform_update(table: dict, doc_id: int):
+                table[doc_id].update(fields)
+
+        updated = []
+
+        def updater(table: dict):
+            if doc_ids is not None:
+                # Unknown IDs are skipped rather than raising.
+                targets = [i for i in doc_ids if i in table]
+            elif cond is not None:
+                targets = [i for i, doc in table.items() if cond(doc)]
+            else:
+                # Neither ``cond`` nor ``doc_ids``: update every document.
+                targets = list(table)
+
+            for target in targets:
+                perform_update(table, target)
+                updated.append(target)
+
+        self._update_table(updater)
+        return updated
 
     def update_multiple(self, updates: Iterable[Tuple[Union[Mapping, Callable[[Mapping], None]], QueryLike]]) -> List[int]:
         """
@@ -170,7 +261,10 @@ class Table:
 
         :returns: a list containing the updated document's ID
         """
-        pass
+        updated = []
+        for fields, cond in updates:
+            updated.extend(self.update(fields, cond))
+        return updated
 
     def upsert(self, document: Mapping, cond: Optional[QueryLike]=None) -> List[int]:
         """
@@ -185,7 +279,26 @@ class Table:
         Document with a doc_id
         :returns: a list containing the updated documents' IDs
         """
-        pass
+        # A Document carrying a doc_id identifies its target directly.
+        doc_id = (document.doc_id
+                  if isinstance(document, self.document_class) else None)
+
+        if doc_id is not None:
+            if self.contains(doc_id=doc_id):
+                return self.update(dict(document), doc_ids=[doc_id])
+            # Not stored yet: insert() keeps the Document's own ID.
+            return [self.insert(document)]
+
+        if cond is None:
+            # Without a doc_id or cond, update() would touch every document.
+            raise ValueError(
+                "If you don't specify a search query, you must specify a "
+                "doc_id. Hint: use a table.Document object.")
+
+        updated = self.update(document, cond)
+        if updated:
+            return updated
+        return [self.insert(document)]
 
     def remove(self, cond: Optional[QueryLike]=None, doc_ids: Optional[Iterable[int]]=None) -> List[int]:
         """
@@ -195,13 +308,35 @@ class Table:
         :param doc_ids: a list of document IDs
         :returns: a list containing the removed documents' ID
         """
-        pass
+        if doc_ids is None and cond is None:
+            # No selector given: remove nothing (truncate() removes all).
+            return []
+
+        removed = []
+
+        def updater(table: dict):
+            if doc_ids is not None:
+                targets = doc_ids
+            else:
+                # Collect first: the table must not change size while it is
+                # being iterated.
+                targets = [i for i, doc in table.items() if cond(doc)]
+
+            for target in targets:
+                # Unknown (or already removed) IDs are skipped.
+                if table.pop(target, None) is not None:
+                    removed.append(target)
+
+        self._update_table(updater)
+        return removed
 
     def truncate(self) -> None:
         """
         Truncate the table by removing all documents.
         """
-        pass
+        self._update_table(lambda table: table.clear())
+        # Restart ID numbering from the (now empty) stored table.
+        self._next_id = None
 
     def count(self, cond: QueryLike) -> int:
         """
@@ -209,13 +344,13 @@ class Table:
 
         :param cond: the condition use
         """
-        pass
+        return len(self.search(cond))
 
     def clear_cache(self) -> None:
         """
         Clear the query cache.
         """
-        pass
+        self._query_cache.clear()
 
     def __len__(self):
         """
@@ -236,7 +371,16 @@ class Table:
         """
         Return the ID for a newly inserted document.
         """
-        pass
+        # ``_next_id`` holds the next free ID; ``None`` means it has to be
+        # recomputed from the stored table.
+        if self._next_id is None:
+            table = self._read_table()
+            # Stored IDs are strings, so compare them as integers.
+            self._next_id = max((int(i) for i in table), default=0) + 1
+
+        next_id = self._next_id
+        self._next_id += 1
+        return next_id
 
     def _read_table(self) -> Dict[str, Mapping]:
         """
@@ -246,7 +390,11 @@ class Table:
         we may not want to convert *all* documents when returning
         only one document for example.
         """
-        pass
+        tables = self._storage.read()
+        if tables is None:
+            # Nothing has been written to the storage yet.
+            return {}
+        return tables.get(self._name, {})
 
     def _update_table(self, updater: Callable[[Dict[int, Mapping]], None]):
         """
@@ -261,4 +409,24 @@ class Table:
         As a further optimization, we don't convert the documents into the
         document class, as the table data will *not* be returned to the user.
         """
-        pass
\ No newline at end of file
+        tables = self._storage.read()
+        if tables is None:
+            tables = {}
+
+        # Storage keys are strings (see _read_table) but the updater
+        # receives int IDs, as its annotation states.
+        table = {
+            int(doc_id): doc
+            for doc_id, doc in tables.get(self._name, {}).items()
+        }
+
+        updater(table)
+
+        tables[self._name] = {
+            str(doc_id): doc for doc_id, doc in table.items()
+        }
+        self._storage.write(tables)
+
+        # Cached query results may be stale now.
+        self._query_cache.clear()