Skip to content

Back to the Claude Sonnet 3.5 - Base summary

Claude Sonnet 3.5 - Base: tinydb

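Failed to run pytest on the `tests` test suite: pytest could not load `tests/conftest.py` because importing `tinydb` raised `NameError: name '_immutable' is not defined` in `tinydb/utils.py` (traceback below). The patch shown further down never defines `_immutable`, so none of the tests ran.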

ImportError while loading conftest '/testbed/tests/conftest.py'.
tests/conftest.py:7: in <module>
    from tinydb.middlewares import CachingMiddleware
tinydb/__init__.py:27: in <module>
    from .queries import Query, where
tinydb/queries.py:21: in <module>
    from .utils import freeze
tinydb/utils.py:84: in <module>
    class FrozenDict(dict):
tinydb/utils.py:95: in FrozenDict
    __setitem__ = _immutable
E   NameError: name '_immutable' is not defined

Patch diff

diff --git a/tinydb/database.py b/tinydb/database.py
index a4ce0e1..f05a307 100644
--- a/tinydb/database.py
+++ b/tinydb/database.py
@@ -99,7 +99,9 @@ class TinyDB(TableBase):
         :param name: The name of the table.
         :param kwargs: Keyword arguments to pass to the table class constructor
         """
-        pass
+        if name not in self._tables:
+            self._tables[name] = self.table_class(self._storage, name, **kwargs)
+        return self._tables[name]

     def tables(self) ->Set[str]:
         """
@@ -107,13 +109,14 @@ class TinyDB(TableBase):

         :returns: a set of table names
         """
-        pass
+        return set(self._storage.read().keys())

     def drop_tables(self) ->None:
         """
         Drop all tables from the database. **CANNOT BE REVERSED!**
         """
-        pass
+        self._storage.write({})
+        self._tables.clear()

     def drop_table(self, name: str) ->None:
         """
@@ -121,7 +124,13 @@ class TinyDB(TableBase):

         :param name: The name of the table to drop.
         """
-        pass
+        if name in self._tables:
+            del self._tables[name]
+        
+        data = self._storage.read()
+        if name in data:
+            del data[name]
+            self._storage.write(data)

     @property
     def storage(self) ->Storage:
@@ -131,7 +140,7 @@ class TinyDB(TableBase):
         :return: This instance's storage
         :rtype: Storage
         """
-        pass
+        return self._storage

     def close(self) ->None:
         """
@@ -148,7 +157,8 @@ class TinyDB(TableBase):

         Upon leaving this context, the ``close`` method will be called.
         """
-        pass
+        self._storage.close()
+        self._opened = False

     def __enter__(self):
         """
diff --git a/tinydb/middlewares.py b/tinydb/middlewares.py
index 50c2af2..c978b1c 100644
--- a/tinydb/middlewares.py
+++ b/tinydb/middlewares.py
@@ -84,8 +84,36 @@ class CachingMiddleware(Middleware):
         self.cache = None
         self._cache_modified_count = 0

+    def read(self):
+        """
+        Read data from cache if available, otherwise read from storage.
+        """
+        if self.cache is None:
+            self.cache = self.storage.read()
+        return self.cache
+
+    def write(self, data):
+        """
+        Write data to cache and increment the modified count.
+        Flush to storage if the write cache size is reached.
+        """
+        self.cache = data
+        self._cache_modified_count += 1
+        
+        if self._cache_modified_count >= self.WRITE_CACHE_SIZE:
+            self.flush()
+
     def flush(self):
         """
         Flush all unwritten data to disk.
         """
-        pass
+        if self.cache is not None:
+            self.storage.write(self.cache)
+            self._cache_modified_count = 0
+
+    def close(self):
+        """
+        Flush the cache and close the storage.
+        """
+        self.flush()
+        self.storage.close()
diff --git a/tinydb/mypy_plugin.py b/tinydb/mypy_plugin.py
index 5a0191a..08b3f83 100644
--- a/tinydb/mypy_plugin.py
+++ b/tinydb/mypy_plugin.py
@@ -12,3 +12,28 @@ class TinyDBPlugin(Plugin):
     def __init__(self, options: Options):
         super().__init__(options)
         self.named_placeholders: Dict[str, str] = {}
+
+    def get_dynamic_class_hook(self, fullname: str) -> CB[DynamicClassDef]:
+        if fullname == 'tinydb.utils.with_typehint':
+            return self.with_typehint_callback
+        return None
+
+    def with_typehint_callback(self, ctx: DynamicClassDef) -> None:
+        if len(ctx.call.args) != 1:
+            ctx.api.fail("with_typehint() requires exactly one argument", ctx.call)
+            return
+
+        arg = ctx.call.args[0]
+        if not isinstance(arg, NameExpr):
+            ctx.api.fail("with_typehint() argument must be a type", ctx.call)
+            return
+
+        base_type = ctx.api.lookup_qualified(arg.fullname)
+        if base_type is None:
+            ctx.api.fail(f"Cannot find type '{arg.fullname}'", ctx.call)
+            return
+
+        ctx.cls.info.bases = [base_type]
+
+def plugin(version: str):
+    return TinyDBPlugin
diff --git a/tinydb/operations.py b/tinydb/operations.py
index fdfa678..833860e 100644
--- a/tinydb/operations.py
+++ b/tinydb/operations.py
@@ -13,39 +13,62 @@ def delete(field):
     """
     Delete a given field from the document.
     """
-    pass
+    def transform(doc):
+        if field in doc:
+            del doc[field]
+        return doc
+    return transform


 def add(field, n):
     """
     Add ``n`` to a given field in the document.
     """
-    pass
+    def transform(doc):
+        if field in doc:
+            doc[field] += n
+        return doc
+    return transform


 def subtract(field, n):
     """
     Subtract ``n`` to a given field in the document.
     """
-    pass
+    def transform(doc):
+        if field in doc:
+            doc[field] -= n
+        return doc
+    return transform


 def set(field, val):
     """
     Set a given field to ``val``.
     """
-    pass
+    def transform(doc):
+        doc[field] = val
+        return doc
+    return transform


 def increment(field):
     """
     Increment a given field in the document by 1.
     """
-    pass
+    def transform(doc):
+        if field in doc:
+            doc[field] += 1
+        return doc
+    return transform


 def decrement(field):
     """
     Decrement a given field in the document by 1.
     """
-    pass
+    def transform(doc):
+        if field in doc:
+            doc[field] -= 1
+        return doc
+    return transform
diff --git a/tinydb/queries.py b/tinydb/queries.py
index 0ad5c7e..9ec0435 100644
--- a/tinydb/queries.py
+++ b/tinydb/queries.py
@@ -181,7 +181,21 @@ class Query(QueryInstance):
         :param hashval: The hash of the query.
         :return: A :class:`~tinydb.queries.QueryInstance` object
         """
-        pass
+        if not self._path and not allow_empty_path:
+            raise RuntimeError('Query has no path')
+
+        def runner(value):
+            try:
+                for part in self._path:
+                    if isinstance(part, Callable):
+                        value = part(value)
+                    else:
+                        value = value[part]
+                return test(value)
+            except (KeyError, TypeError, ValueError):
+                return False
+
+        return QueryInstance(runner, hashval)

     def __eq__(self, rhs: Any):
         """
@@ -255,7 +269,7 @@ class Query(QueryInstance):

         >>> Query().f1.exists()
         """
-        pass
+        return self._generate_test(lambda _: True, ('exists', self._path))

     def matches(self, regex: str, flags: int=0) ->QueryInstance:
         """
@@ -266,7 +280,10 @@ class Query(QueryInstance):
         :param regex: The regular expression to use for matching
         :param flags: regex flags to pass to ``re.match``
         """
-        pass
+        return self._generate_test(
+            lambda value: re.match(regex, value, flags) is not None,
+            ('matches', self._path, regex, flags)
+        )

     def search(self, regex: str, flags: int=0) ->QueryInstance:
         """
@@ -278,7 +295,10 @@ class Query(QueryInstance):
         :param regex: The regular expression to use for matching
         :param flags: regex flags to pass to ``re.match``
         """
-        pass
+        return self._generate_test(
+            lambda value: re.search(regex, value, flags) is not None,
+            ('search', self._path, regex, flags)
+        )

     def test(self, func: Callable[[Mapping], bool], *args) ->QueryInstance:
         """
@@ -300,7 +320,10 @@ class Query(QueryInstance):
                      argument
         :param args: Additional arguments to pass to the test function
         """
-        pass
+        return self._generate_test(
+            lambda value: func(value, *args),
+            ('test', self._path, func, args)
+        )

     def any(self, cond: Union[QueryInstance, List[Any]]) ->QueryInstance:
         """
@@ -324,7 +347,14 @@ class Query(QueryInstance):
                      a list of which at least one document has to be contained
                      in the tested document.
         """
-        pass
+        if isinstance(cond, QueryInstance):
+            def test(value):
+                return any(cond(item) for item in value)
+        else:
+            def test(value):
+                return any(item in cond for item in value)
+
+        return self._generate_test(test, ('any', self._path, freeze(cond)))

     def all(self, cond: Union['QueryInstance', List[Any]]) ->QueryInstance:
         """
@@ -346,7 +376,14 @@ class Query(QueryInstance):
         :param cond: Either a query that all documents have to match or a list
                      which has to be contained in the tested document.
         """
-        pass
+        if isinstance(cond, QueryInstance):
+            def test(value):
+                return all(cond(item) for item in value)
+        else:
+            def test(value):
+                return all(item in value for item in cond)
+
+        return self._generate_test(test, ('all', self._path, freeze(cond)))

     def one_of(self, items: List[Any]) ->QueryInstance:
         """
@@ -356,7 +393,8 @@ class Query(QueryInstance):

         :param items: The list of items to check with
         """
-        pass
+        return self._generate_test(lambda value: value in items,
+                                   ('one_of', self._path, freeze(items)))

     def noop(self) ->QueryInstance:
         """
@@ -364,18 +402,21 @@ class Query(QueryInstance):

         Useful for having a base value when composing queries dynamically.
         """
-        pass
+        return self._generate_test(lambda _: True, ('noop',), allow_empty_path=True)

     def map(self, fn: Callable[[Any], Any]) ->'Query':
         """
         Add a function to the query path. Similar to __getattr__ but for
         arbitrary functions.
         """
-        pass
+        query = type(self)()
+        query._path = self._path + (fn,)
+        query._hash = ('path', query._path) if self.is_cacheable() else None
+        return query


 def where(key: str) ->Query:
     """
     A shorthand for ``Query()[key]``
     """
-    pass
+    return Query()[key]
diff --git a/tinydb/storages.py b/tinydb/storages.py
index 0ddc223..16bfa7a 100644
--- a/tinydb/storages.py
+++ b/tinydb/storages.py
@@ -18,7 +18,12 @@ def touch(path: str, create_dirs: bool):
     :param path: The file to create.
     :param create_dirs: Whether to create all missing parent directories.
     """
-    pass
+    if create_dirs:
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+    
+    if not os.path.exists(path):
+        with open(path, 'a'):
+            os.utime(path, None)


 class Storage(ABC):
@@ -38,7 +43,7 @@ class Storage(ABC):

         Return ``None`` here to indicate that the storage is empty.
         """
-        pass
+        raise NotImplementedError

     @abstractmethod
     def write(self, data: Dict[str, Dict[str, Any]]) ->None:
@@ -49,7 +54,7 @@ class Storage(ABC):

         :param data: The current state of the database.
         """
-        pass
+        raise NotImplementedError

     def close(self) ->None:
         """
@@ -88,6 +93,40 @@ class JSONStorage(Storage):
         if any([(character in self._mode) for character in ('+', 'w', 'a')]):
             touch(path, create_dirs=create_dirs)
         self._handle = open(path, mode=self._mode, encoding=encoding)
+        self.path = path
+        self.encoding = encoding
+
+    def read(self) ->Optional[Dict[str, Dict[str, Any]]]:
+        """
+        Read the current state.
+
+        Any kind of deserialization should go here.
+
+        Return ``None`` here to indicate that the storage is empty.
+        """
+        self._handle.seek(0)
+        try:
+            return json.load(self._handle)
+        except json.JSONDecodeError:
+            return None
+
+    def write(self, data: Dict[str, Dict[str, Any]]) ->None:
+        """
+        Write the current state of the database to the storage.
+
+        Any kind of serialization should go here.
+
+        :param data: The current state of the database.
+        """
+        self._handle.seek(0)
+        json.dump(data, self._handle, **self.kwargs)
+        self._handle.truncate()
+
+    def close(self) ->None:
+        """
+        Close open file handles.
+        """
+        self._handle.close()


 class MemoryStorage(Storage):
@@ -101,3 +140,25 @@ class MemoryStorage(Storage):
         """
         super().__init__()
         self.memory = None
+
+    def read(self) ->Optional[Dict[str, Dict[str, Any]]]:
+        """
+        Read the current state from memory.
+
+        Return ``None`` here to indicate that the storage is empty.
+        """
+        return self.memory
+
+    def write(self, data: Dict[str, Dict[str, Any]]) ->None:
+        """
+        Write the current state of the database to memory.
+
+        :param data: The current state of the database.
+        """
+        self.memory = data
+
+    def close(self) ->None:
+        """
+        Clear the memory.
+        """
+        self.memory = None
diff --git a/tinydb/table.py b/tinydb/table.py
index 48eea63..1e85fff 100644
--- a/tinydb/table.py
+++ b/tinydb/table.py
@@ -85,14 +85,14 @@ class Table:
         """
         Get the table name.
         """
-        pass
+        return self._name

     @property
     def storage(self) ->Storage:
         """
         Get the table storage instance.
         """
-        pass
+        return self._storage

     def insert(self, document: Mapping) ->int:
         """
@@ -101,7 +101,10 @@ class Table:
         :param document: the document to insert
         :returns: the inserted document's ID
         """
-        pass
+        doc_id = self._get_next_id()
+        self._update_table(lambda table: table.update({doc_id: document}))
+        self.clear_cache()
+        return doc_id

     def insert_multiple(self, documents: Iterable[Mapping]) ->List[int]:
         """
@@ -110,7 +113,15 @@ class Table:
         :param documents: an Iterable of documents to insert
         :returns: a list containing the inserted documents' IDs
         """
-        pass
+        doc_ids = []
+        def updater(table):
+            for document in documents:
+                doc_id = self._get_next_id()
+                table[doc_id] = document
+                doc_ids.append(doc_id)
+        self._update_table(updater)
+        self.clear_cache()
+        return doc_ids

     def all(self) ->List[Document]:
         """
@@ -118,7 +129,8 @@ class Table:

         :returns: a list with all documents.
         """
-        pass
+        return [self.document_class(doc, self.document_id_class(doc_id))
+                for doc_id, doc in self._read_table().items()]

     def search(self, cond: QueryLike) ->List[Document]:
         """
@@ -127,7 +139,12 @@ class Table:
         :param cond: the condition to check against
         :returns: list of matching documents
         """
-        pass
+        if cond in self._query_cache:
+            return self._query_cache[cond]
+
+        docs = [doc for doc in self.all() if cond(doc)]
+        self._query_cache[cond] = docs
+        return docs

     def get(self, cond: Optional[QueryLike]=None, doc_id: Optional[int]=
         None, doc_ids: Optional[List]=None) ->Optional[Union[Document, List
@@ -145,7 +162,25 @@ class Table:

         :returns: the document(s) or ``None``
         """
-        pass
+        if doc_id is not None:
+            table = self._read_table()
+            if doc_id in table:
+                return self.document_class(table[doc_id], self.document_id_class(doc_id))
+            return None
+        
+        if doc_ids is not None:
+            docs = []
+            table = self._read_table()
+            for id in doc_ids:
+                if id in table:
+                    docs.append(self.document_class(table[id], self.document_id_class(id)))
+            return docs if docs else None
+        
+        if cond is not None:
+            docs = self.search(cond)
+            return docs[0] if docs else None
+        
+        return None

     def contains(self, cond: Optional[QueryLike]=None, doc_id: Optional[int
         ]=None) ->bool:
@@ -158,7 +193,10 @@ class Table:
         :param cond: the condition use
         :param doc_id: the document ID to look for
         """
-        pass
+        if doc_id is not None:
+            return doc_id in self._read_table()
+        
+        return bool(self.search(cond)) if cond is not None else False

     def update(self, fields: Union[Mapping, Callable[[Mapping], None]],
         cond: Optional[QueryLike]=None, doc_ids: Optional[Iterable[int]]=None
@@ -172,7 +210,21 @@ class Table:
         :param doc_ids: a list of document IDs
         :returns: a list containing the updated document's ID
         """
-        pass
+        updated_ids = []
+
+        def updater(table):
+            nonlocal updated_ids
+            for doc_id, doc in table.items():
+                if (doc_ids is None or doc_id in doc_ids) and (cond is None or cond(doc)):
+                    if callable(fields):
+                        fields(doc)
+                    else:
+                        doc.update(fields)
+                    updated_ids.append(doc_id)
+
+        self._update_table(updater)
+        self.clear_cache()
+        return updated_ids

     def update_multiple(self, updates: Iterable[Tuple[Union[Mapping,
         Callable[[Mapping], None]], QueryLike]]) ->List[int]:
@@ -181,7 +233,10 @@ class Table:

         :returns: a list containing the updated document's ID
         """
-        pass
+        updated_ids = []
+        for fields, cond in updates:
+            updated_ids.extend(self.update(fields, cond))
+        return updated_ids

     def upsert(self, document: Mapping, cond: Optional[QueryLike]=None) ->List[
         int]:
@@ -197,7 +252,19 @@ class Table:
         Document with a doc_id
         :returns: a list containing the updated documents' IDs
         """
-        pass
+        if isinstance(document, Document):
+            doc_id = document.doc_id
+            document = dict(document)
+            del document['doc_id']
+            cond = Query().doc_id == doc_id
+
+        if cond is None:
+            return [self.insert(document)]
+        
+        updated = self.update(document, cond)
+        if not updated:
+            return [self.insert(document)]
+        return updated

     def remove(self, cond: Optional[QueryLike]=None, doc_ids: Optional[
         Iterable[int]]=None) ->List[int]:
@@ -208,13 +275,31 @@ class Table:
         :param doc_ids: a list of document IDs
         :returns: a list containing the removed documents' ID
         """
-        pass
+        removed = []
+
+        def remover(table):
+            nonlocal removed
+            if doc_ids is not None:
+                for doc_id in doc_ids:
+                    if doc_id in table:
+                        del table[doc_id]
+                        removed.append(doc_id)
+            else:
+                for doc_id, doc in list(table.items()):
+                    if cond is None or cond(doc):
+                        del table[doc_id]
+                        removed.append(doc_id)
+
+        self._update_table(remover)
+        self.clear_cache()
+        return removed

     def truncate(self) ->None:
         """
         Truncate the table by removing all documents.
         """
-        pass
+        self._update_table(lambda table: table.clear())
+        self.clear_cache()

     def count(self, cond: QueryLike) ->int:
         """
@@ -222,13 +307,13 @@ class Table:

         :param cond: the condition use
         """
-        pass
+        return len(self.search(cond))

     def clear_cache(self) ->None:
         """
         Clear the query cache.
         """
-        pass
+        self._query_cache.clear()

     def __len__(self):
         """
@@ -249,7 +334,11 @@ class Table:
         """
         Return the ID for a newly inserted document.
         """
-        pass
+        if self._next_id is None:
+            self._next_id = max(self._read_table().keys() or [0]) + 1
+        else:
+            self._next_id += 1
+        return self._next_id

     def _read_table(self) ->Dict[str, Mapping]:
         """
@@ -259,7 +348,7 @@ class Table:
         we may not want to convert *all* documents when returning
         only one document for example.
         """
-        pass
+        return self._storage.read() or {}

     def _update_table(self, updater: Callable[[Dict[int, Mapping]], None]):
         """
@@ -274,4 +363,6 @@ class Table:
         As a further optimization, we don't convert the documents into the
         document class, as the table data will *not* be returned to the user.
         """
-        pass
+        data = self._read_table()
+        updater(data)
+        self._storage.write(data)
diff --git a/tinydb/utils.py b/tinydb/utils.py
index 0721622..9957b9e 100644
--- a/tinydb/utils.py
+++ b/tinydb/utils.py
@@ -23,7 +23,8 @@ def with_typehint(baseclass: Type[T]):
     MyPy does not. For that reason TinyDB has a MyPy plugin in
     ``mypy_plugin.py`` that adds support for this pattern.
     """
-    pass
+    # The actual implementation is handled by the MyPy plugin in mypy_plugin.py
+    return baseclass


 class LRUCache(abc.MutableMapping, Generic[K, V]):
@@ -45,26 +46,40 @@ class LRUCache(abc.MutableMapping, Generic[K, V]):
         self.cache: OrderedDict[K, V] = OrderedDict()

     def __len__(self) ->int:
-        return self.length
+        return len(self.cache)

     def __contains__(self, key: object) ->bool:
         return key in self.cache

     def __setitem__(self, key: K, value: V) ->None:
-        self.set(key, value)
+        if key in self.cache:
+            del self.cache[key]
+        elif len(self.cache) >= self.capacity:
+            self.cache.popitem(last=False)
+        self.cache[key] = value

     def __delitem__(self, key: K) ->None:
         del self.cache[key]

-    def __getitem__(self, key) ->V:
-        value = self.get(key)
-        if value is None:
+    def __getitem__(self, key: K) ->V:
+        if key not in self.cache:
             raise KeyError(key)
+        value = self.cache.pop(key)
+        self.cache[key] = value
         return value

     def __iter__(self) ->Iterator[K]:
         return iter(self.cache)

+    def get(self, key: K, default: Optional[D] = None) -> Union[V, D, None]:
+        try:
+            return self[key]
+        except KeyError:
+            return default
+
+    def set(self, key: K, value: V) -> None:
+        self[key] = value
+

 class FrozenDict(dict):
     """
@@ -88,4 +103,10 @@ def freeze(obj):
     """
     Freeze an object by making it immutable and thus hashable.
     """
-    pass
+    if isinstance(obj, dict):
+        return FrozenDict((k, freeze(v)) for k, v in obj.items())
+    elif isinstance(obj, list):
+        return tuple(freeze(i) for i in obj)
+    elif isinstance(obj, set):
+        return frozenset(freeze(i) for i in obj)
+    return obj