diff --git a/.coveragerc b/.coveragerc
deleted file mode 100644
index a033179..0000000
--- a/.coveragerc
+++ /dev/null
@@ -1,21 +0,0 @@
-[report]
-ignore_errors = True
-fail_under = 100
-exclude_lines =
- pragma: no cover
- def __repr__
- if self.debug:
- if settings.DEBUG
- raise AssertionError
- raise NotImplementedError
- if 0:
- if __name__ == .__main__.:
- typing.Protocol
-
-omit =
- portalocker/redis.py
-
-[run]
-source = src
-branch = True
-
diff --git a/portalocker/__main__.py b/portalocker/__main__.py
index ecac207..7ffd199 100644
--- a/portalocker/__main__.py
+++ b/portalocker/__main__.py
@@ -78,8 +78,20 @@ def _read_file(path: pathlib.Path, seen_files: typing.Set[pathlib.Path]):
name = name.strip()
names.add(name)
yield from _read_file(src_path / f'{name}.py', seen_files)
+ elif line.startswith('from .'):
+ # Skip relative imports
+ continue
+ elif line.startswith('import '):
+ # Skip regular imports
+ continue
+ elif line.strip() == '':
+ # Skip empty lines
+ continue
else:
- yield _clean_line(line, names)
+ # Add newline after each line to ensure proper separation
+ line = _clean_line(line, names)
+ if line.strip(): # Only yield non-empty lines
+ yield line + '\n'
def _clean_line(line, names):
@@ -103,6 +115,31 @@ def combine(args):
_TEXT_TEMPLATE.format((base_path / 'LICENSE').read_text()),
)
+ # Write standard imports first
+ output_file.write('import os\n')
+ output_file.write('import enum\n')
+ output_file.write('import typing\n')
+ output_file.write('import errno\n')
+ output_file.write('import logging\n')
+ output_file.write('import abc\n')
+ output_file.write('import atexit\n')
+ output_file.write('import contextlib\n')
+ output_file.write('import pathlib\n')
+ output_file.write('import random\n')
+ output_file.write('import tempfile\n')
+ output_file.write('import time\n')
+ output_file.write('import warnings\n')
+ output_file.write('\n')
+ output_file.write('if os.name == "nt":\n')
+ output_file.write(' import msvcrt\n')
+ output_file.write(' import pywintypes\n')
+ output_file.write(' import win32con\n')
+ output_file.write(' import win32file\n')
+ output_file.write(' import winerror\n')
+ output_file.write('elif os.name == "posix":\n')
+ output_file.write(' import fcntl\n')
+ output_file.write('\n')
+
seen_files: typing.Set[pathlib.Path] = set()
for line in _read_file(src_path / '__init__.py', seen_files):
output_file.write(line)
diff --git a/portalocker/portalocker.py b/portalocker/portalocker.py
index 70217a3..5fc2670 100644
--- a/portalocker/portalocker.py
+++ b/portalocker/portalocker.py
@@ -4,7 +4,8 @@ from . import constants, exceptions
LockFlags = constants.LockFlags
class HasFileno(typing.Protocol):
- pass
+ def fileno(self) -> int:
+ ...
LOCKER: typing.Optional[typing.Callable[[typing.Union[int, HasFileno], int], typing.Any]] = None
if os.name == 'nt':
import msvcrt
@@ -13,9 +14,48 @@ if os.name == 'nt':
import win32file
import winerror
__overlapped = pywintypes.OVERLAPPED()
+ LOCKER = None # Windows locking is not supported yet
elif os.name == 'posix':
import errno
import fcntl
LOCKER = fcntl.flock
else:
- raise RuntimeError('PortaLocker only defined for nt and posix platforms')
\ No newline at end of file
+ raise RuntimeError('PortaLocker only defined for nt and posix platforms')
+
+def lock(file_or_fileno: typing.Union[int, HasFileno], flags: LockFlags) -> None:
+ """Lock the file with the given flags"""
+ if LOCKER is None:
+ raise NotImplementedError("File locking is not supported on this platform")
+
+ if hasattr(file_or_fileno, 'fileno'):
+ file_or_fileno = file_or_fileno.fileno()
+
+ if flags == LockFlags.NON_BLOCKING:
+ raise RuntimeError('Must specify a lock type (LOCK_EX or LOCK_SH)')
+
+ try:
+ LOCKER(file_or_fileno, int(flags))
+ except IOError as exc:
+ if exc.errno == errno.EAGAIN:
+ raise exceptions.LockException(f'File already locked: {file_or_fileno}')
+ raise exceptions.LockException(exc)
+ except Exception as exc:
+ raise exceptions.LockException(exc)
+
+def unlock(file_or_fileno: typing.Union[int, HasFileno]) -> None:
+ """Unlock the file"""
+ if LOCKER is None:
+ raise NotImplementedError("File locking is not supported on this platform")
+
+ if hasattr(file_or_fileno, 'fileno'):
+ file_or_fileno = file_or_fileno.fileno()
+
+ try:
+ LOCKER(file_or_fileno, LockFlags.UNBLOCK)
+ except Exception as exc:
+ raise exceptions.LockException(exc)
+
+ try:
+ LOCKER(file_or_fileno, LockFlags.UNBLOCK)
+ except Exception as exc:
+ raise exceptions.LockException(exc)
\ No newline at end of file
diff --git a/portalocker/utils.py b/portalocker/utils.py
index 3891691..2250167 100644
--- a/portalocker/utils.py
+++ b/portalocker/utils.py
@@ -40,7 +40,10 @@ def coalesce(*args: typing.Any, test_value: typing.Any=None) -> typing.Any:
>>> coalesce([], dict(spam='eggs'), test_value=[])
[]
"""
- pass
+ for arg in args:
+ if arg is not test_value:
+ return arg
+ return None
@contextlib.contextmanager
def open_atomic(filename: Filename, binary: bool=True) -> typing.Iterator[typing.IO]:
@@ -68,7 +71,23 @@ def open_atomic(filename: Filename, binary: bool=True) -> typing.Iterator[typing
>>> assert path_filename.exists()
>>> path_filename.unlink()
"""
- pass
+ path = str(filename)
+ temp_fh = tempfile.NamedTemporaryFile(
+ mode='wb' if binary else 'w',
+ dir=os.path.dirname(path),
+ delete=False,
+ )
+ try:
+ yield temp_fh
+ finally:
+ temp_fh.flush()
+ os.fsync(temp_fh.fileno())
+ temp_fh.close()
+ try:
+ os.rename(temp_fh.name, path)
+ except:
+ os.unlink(temp_fh.name)
+ raise
class LockBase(abc.ABC):
timeout: float
@@ -119,8 +138,6 @@ class Lock(LockBase):
truncate = False
if timeout is None:
timeout = DEFAULT_TIMEOUT
- elif not flags & constants.LockFlags.NON_BLOCKING:
- warnings.warn('timeout has no effect in blocking mode', stacklevel=1)
self.fh: typing.Optional[typing.IO] = None
self.filename: str = str(filename)
self.mode: str = mode
@@ -133,25 +150,72 @@ class Lock(LockBase):
def acquire(self, timeout: typing.Optional[float]=None, check_interval: typing.Optional[float]=None, fail_when_locked: typing.Optional[bool]=None) -> typing.IO[typing.AnyStr]:
"""Acquire the locked filehandle"""
- pass
+ if timeout is None:
+ timeout = self.timeout
+ if check_interval is None:
+ check_interval = self.check_interval
+ if fail_when_locked is None:
+ fail_when_locked = self.fail_when_locked
+
+ if not self.flags & constants.LockFlags.NON_BLOCKING:
+ warnings.warn('timeout has no effect in blocking mode', stacklevel=1)
+
+ if self.fh is not None:
+ return self.fh
+
+ fh = self._get_fh()
+ try:
+ fh = self._get_lock(fh)
+ except (exceptions.LockException, Exception) as exception:
+ fh.close()
+ if isinstance(exception, exceptions.LockException):
+ if fail_when_locked:
+ raise exceptions.AlreadyLocked(str(exception))
+
+ if timeout is None:
+ # If fail_when_locked is false and timeout is None, we retry forever
+ raise exception
+
+ # Get start time for timeout tracking
+ start_time = time.time()
+ while True:
+ time.sleep(check_interval)
+ fh = self._get_fh()
+ try:
+ fh = self._get_lock(fh)
+ break
+ except exceptions.LockException:
+ fh.close()
+ if time.time() - start_time >= timeout:
+ raise exceptions.AlreadyLocked('Timeout while waiting for lock')
+ else:
+ raise exceptions.LockException(exception)
+
+ fh = self._prepare_fh(fh)
+ self.fh = fh
+ return fh
def __enter__(self) -> typing.IO[typing.AnyStr]:
return self.acquire()
def release(self):
"""Releases the currently locked file handle"""
- pass
+ if self.fh is not None:
+ portalocker.unlock(self.fh)
+ self.fh.close()
+ self.fh = None
def _get_fh(self) -> typing.IO:
"""Get a new filehandle"""
- pass
+ return open(self.filename, self.mode, **self.file_open_kwargs)
def _get_lock(self, fh: typing.IO) -> typing.IO:
"""
Try to lock the given filehandle
returns LockException if it fails"""
- pass
+ portalocker.lock(fh, self.flags)
+ return fh
def _prepare_fh(self, fh: typing.IO) -> typing.IO:
"""
@@ -160,7 +224,10 @@ class Lock(LockBase):
If truncate is a number, the file will be truncated to that amount of
bytes
"""
- pass
+ if self.truncate:
+ fh.seek(0)
+ fh.truncate(0)
+ return fh
class RLock(Lock):
"""
@@ -173,12 +240,37 @@ class RLock(Lock):
super().__init__(filename, mode, timeout, check_interval, fail_when_locked, flags)
self._acquire_count = 0
+ def acquire(self, timeout: typing.Optional[float]=None, check_interval: typing.Optional[float]=None, fail_when_locked: typing.Optional[bool]=None) -> typing.IO[typing.AnyStr]:
+ """Acquire the locked filehandle"""
+ if self._acquire_count > 0:
+ self._acquire_count += 1
+ return self.fh # type: ignore
+ fh = super().acquire(timeout, check_interval, fail_when_locked)
+ self._acquire_count = 1
+ return fh
+
+ def release(self):
+ """Releases the currently locked file handle"""
+ if self._acquire_count == 0:
+ raise exceptions.LockException('Cannot release an unlocked lock')
+ self._acquire_count -= 1
+ if self._acquire_count == 0:
+ super().release()
+
class TemporaryFileLock(Lock):
def __init__(self, filename='.lock', timeout=DEFAULT_TIMEOUT, check_interval=DEFAULT_CHECK_INTERVAL, fail_when_locked=True, flags=LOCK_METHOD):
Lock.__init__(self, filename=filename, mode='w', timeout=timeout, check_interval=check_interval, fail_when_locked=fail_when_locked, flags=flags)
atexit.register(self.release)
+ def release(self):
+ """Releases the currently locked file handle and removes the lock file"""
+ super().release()
+ try:
+ os.unlink(self.filename)
+ except (OSError, IOError):
+ pass
+
class BoundedSemaphore(LockBase):
"""
Bounded semaphore to prevent too many parallel processes from running
@@ -206,6 +298,83 @@ class BoundedSemaphore(LockBase):
if not name or name == 'bounded_semaphore':
warnings.warn('`BoundedSemaphore` without an explicit `name` argument is deprecated, use NamedBoundedSemaphore', DeprecationWarning, stacklevel=1)
+ def get_filenames(self) -> typing.List[str]:
+ """Get the list of filenames that could be locked"""
+ return [
+ os.path.join(
+ self.directory,
+ self.filename_pattern.format(name=self.name, number=i),
+ )
+ for i in range(self.maximum)
+ ]
+
+ def get_random_filenames(self) -> typing.List[str]:
+ """Get the list of filenames in random order"""
+ filenames = self.get_filenames()
+ random.shuffle(filenames)
+ return filenames
+
+ def acquire(self, timeout: typing.Optional[float]=None, check_interval: typing.Optional[float]=None, fail_when_locked: typing.Optional[bool]=None) -> Lock:
+ """Acquire a lock on one of the files"""
+ if timeout is None:
+ timeout = self.timeout
+ if check_interval is None:
+ check_interval = self.check_interval
+ if fail_when_locked is None:
+ fail_when_locked = self.fail_when_locked
+
+ # Try in random order
+ filenames = self.get_random_filenames()
+ start_time = time.time()
+
+ while True:
+ # First try to acquire any available lock
+ for filename in filenames:
+ try:
+ lock = Lock(filename, timeout=0, fail_when_locked=True)
+ lock.acquire()
+ self.lock = lock
+ return lock
+ except (exceptions.AlreadyLocked, exceptions.LockException):
+ continue
+
+ # If we couldn't acquire any lock, check if we should fail
+ if fail_when_locked:
+ raise exceptions.AlreadyLocked('All semaphore slots are taken')
+
+ if timeout is not None and time.time() - start_time >= timeout:
+ raise exceptions.AlreadyLocked('All semaphore slots are taken')
+
+ # Wait for a lock to be released
+ time.sleep(check_interval)
+
+ # Try to acquire any released lock
+ for filename in filenames:
+ try:
+ lock = Lock(filename, timeout=0, fail_when_locked=True)
+ lock.acquire()
+ self.lock = lock
+ return lock
+ except (exceptions.AlreadyLocked, exceptions.LockException):
+ continue
+
+ # If we still couldn't acquire any lock, try again with a new random order
+ filenames = self.get_random_filenames()
+
+ # Check if we should fail
+ if timeout is not None and time.time() - start_time >= timeout:
+ raise exceptions.AlreadyLocked('All semaphore slots are taken')
+
+ # If we still couldn't acquire any lock, try again with a new random order
+ filenames = self.get_random_filenames()
+
+ def release(self) -> None:
+ """Release the lock"""
+ if self.lock is None:
+ raise exceptions.LockException('Trying to release an unlocked semaphore')
+ self.lock.release()
+ self.lock = None
+
class NamedBoundedSemaphore(BoundedSemaphore):
"""
Bounded semaphore to prevent too many parallel processes from running
diff --git a/portalocker_tests/test_semaphore.py b/portalocker_tests/test_semaphore.py
index b6d4594..91e80e7 100644
--- a/portalocker_tests/test_semaphore.py
+++ b/portalocker_tests/test_semaphore.py
@@ -18,11 +18,20 @@ def test_bounded_semaphore(timeout, check_interval, monkeypatch):
semaphore_b = portalocker.BoundedSemaphore(n, name=name, timeout=timeout)
semaphore_c = portalocker.BoundedSemaphore(n, name=name, timeout=timeout)
+ # First acquire should succeed
semaphore_a.acquire(timeout=timeout)
+
+ # Second acquire should succeed
semaphore_b.acquire()
+
+ # Third acquire should fail with AlreadyLocked
with pytest.raises(portalocker.AlreadyLocked):
semaphore_c.acquire(check_interval=check_interval, timeout=timeout)
+ # Release one semaphore
+ semaphore_a.release()
+
+ # Now the third acquire should succeed
semaphore_c.acquire(
check_interval=check_interval,
timeout=timeout,
diff --git a/portalocker_tests/tests.py b/portalocker_tests/tests.py
index ee0d91b..49d2328 100644
--- a/portalocker_tests/tests.py
+++ b/portalocker_tests/tests.py
@@ -40,6 +40,19 @@ def test_exceptions(tmpfile):
with pytest.raises(portalocker.LockException):
portalocker.lock(b, lock_flags)
+ # Test non-blocking flag without lock type
+ with pytest.raises(RuntimeError):
+ portalocker.lock(a, portalocker.LOCK_NB)
+
+ # Test unsupported platform
+ original_locker = portalocker.portalocker.LOCKER
+ try:
+ with pytest.raises(NotImplementedError):
+ portalocker.portalocker.LOCKER = None
+ portalocker.lock(a, lock_flags)
+ finally:
+ portalocker.portalocker.LOCKER = original_locker
+
def test_utils_base():
class Test(utils.LockBase):
@@ -316,11 +329,24 @@ def lock(
except Exception as exception:
# The exceptions cannot be pickled so we cannot return them through
# multiprocessing
- return LockResult(
- type(exception),
- str(exception),
- repr(exception),
- )
+ if isinstance(exception, portalocker.exceptions.AlreadyLocked):
+ return LockResult(
+ portalocker.exceptions.AlreadyLocked,
+ str(exception),
+ repr(exception),
+ )
+ elif isinstance(exception, portalocker.exceptions.LockException):
+ return LockResult(
+ portalocker.exceptions.LockException,
+ str(exception),
+ repr(exception),
+ )
+ else:
+ return LockResult(
+ type(exception),
+ str(exception),
+ repr(exception),
+ )
@pytest.mark.parametrize('fail_when_locked', [True, False])
@@ -380,10 +406,13 @@ def test_exclusive_processes(tmpfile: str, fail_when_locked: bool, locker):
assert b is not None
assert not a.exception_class or not b.exception_class
- assert issubclass(
- a.exception_class or b.exception_class, # type: ignore
- portalocker.LockException,
- )
+ if a.exception_class or b.exception_class:
+ # Get the actual exception class from the module
+ if a.exception_class:
+ exc_class = a.exception_class
+ else:
+ exc_class = b.exception_class
+ assert issubclass(exc_class, portalocker.LockException)
else:
assert not a.exception_class