back to Reference (Gold) summary
Reference (Gold): paramiko
Pytest summary for the `tests` test suite
status | count |
---|---|
passed | 557 |
skipped | 21 |
total | 578 |
collected | 578 |
Failed pytests: none (557 passed + 21 skipped = 578 collected)
Patch diff
diff --git a/paramiko/_version.py b/paramiko/_version.py
index 9890fe29..65688cab 100644
--- a/paramiko/_version.py
+++ b/paramiko/_version.py
@@ -1,2 +1,2 @@
-__version_info__ = 3, 4, 1
-__version__ = '.'.join(map(str, __version_info__))
+__version_info__ = (3, 4, 1)
+__version__ = ".".join(map(str, __version_info__))
diff --git a/paramiko/_winapi.py b/paramiko/_winapi.py
index f02b3c7d..42954574 100644
--- a/paramiko/_winapi.py
+++ b/paramiko/_winapi.py
@@ -5,17 +5,51 @@ in jaraco.windows (3.4.1).
If you encounter issues with this module, please consider reporting the issues
in jaraco.windows and asking the author to port the fixes back here.
"""
+
import builtins
import ctypes.wintypes
+
from paramiko.util import u
+######################
+# jaraco.windows.error
+
+
def format_system_message(errno):
"""
Call FormatMessage with a system error number to retrieve
the descriptive error message.
"""
- pass
+ # first some flags used by FormatMessageW
+ ALLOCATE_BUFFER = 0x100
+ FROM_SYSTEM = 0x1000
+
+ # Let FormatMessageW allocate the buffer (we'll free it below)
+ # Also, let it know we want a system error message.
+ flags = ALLOCATE_BUFFER | FROM_SYSTEM
+ source = None
+ message_id = errno
+ language_id = 0
+ result_buffer = ctypes.wintypes.LPWSTR()
+ buffer_size = 0
+ arguments = None
+ bytes = ctypes.windll.kernel32.FormatMessageW(
+ flags,
+ source,
+ message_id,
+ language_id,
+ ctypes.byref(result_buffer),
+ buffer_size,
+ arguments,
+ )
+ # note the following will cause an infinite loop if GetLastError
+ # repeatedly returns an error that cannot be formatted, although
+ # this should not happen.
+ handle_nonzero_success(bytes)
+ message = result_buffer.value
+ ctypes.windll.kernel32.LocalFree(result_buffer)
+ return message
class WindowsError(builtins.WindowsError):
@@ -29,38 +63,71 @@ class WindowsError(builtins.WindowsError):
args = 0, strerror, None, value
super().__init__(*args)
+ @property
+ def message(self):
+ return self.strerror
+
+ @property
+ def code(self):
+ return self.winerror
+
def __str__(self):
return self.message
def __repr__(self):
- return '{self.__class__.__name__}({self.winerror})'.format(**vars())
+ return "{self.__class__.__name__}({self.winerror})".format(**vars())
+
+def handle_nonzero_success(result):
+ if result == 0:
+ raise WindowsError()
+
+
+###########################
+# jaraco.windows.api.memory
+
+GMEM_MOVEABLE = 0x2
-GMEM_MOVEABLE = 2
GlobalAlloc = ctypes.windll.kernel32.GlobalAlloc
GlobalAlloc.argtypes = ctypes.wintypes.UINT, ctypes.c_size_t
GlobalAlloc.restype = ctypes.wintypes.HANDLE
+
GlobalLock = ctypes.windll.kernel32.GlobalLock
-GlobalLock.argtypes = ctypes.wintypes.HGLOBAL,
+GlobalLock.argtypes = (ctypes.wintypes.HGLOBAL,)
GlobalLock.restype = ctypes.wintypes.LPVOID
+
GlobalUnlock = ctypes.windll.kernel32.GlobalUnlock
-GlobalUnlock.argtypes = ctypes.wintypes.HGLOBAL,
+GlobalUnlock.argtypes = (ctypes.wintypes.HGLOBAL,)
GlobalUnlock.restype = ctypes.wintypes.BOOL
+
GlobalSize = ctypes.windll.kernel32.GlobalSize
-GlobalSize.argtypes = ctypes.wintypes.HGLOBAL,
+GlobalSize.argtypes = (ctypes.wintypes.HGLOBAL,)
GlobalSize.restype = ctypes.c_size_t
+
CreateFileMapping = ctypes.windll.kernel32.CreateFileMappingW
-CreateFileMapping.argtypes = [ctypes.wintypes.HANDLE, ctypes.c_void_p,
- ctypes.wintypes.DWORD, ctypes.wintypes.DWORD, ctypes.wintypes.DWORD,
- ctypes.wintypes.LPWSTR]
+CreateFileMapping.argtypes = [
+ ctypes.wintypes.HANDLE,
+ ctypes.c_void_p,
+ ctypes.wintypes.DWORD,
+ ctypes.wintypes.DWORD,
+ ctypes.wintypes.DWORD,
+ ctypes.wintypes.LPWSTR,
+]
CreateFileMapping.restype = ctypes.wintypes.HANDLE
+
MapViewOfFile = ctypes.windll.kernel32.MapViewOfFile
MapViewOfFile.restype = ctypes.wintypes.HANDLE
+
UnmapViewOfFile = ctypes.windll.kernel32.UnmapViewOfFile
-UnmapViewOfFile.argtypes = ctypes.wintypes.HANDLE,
+UnmapViewOfFile.argtypes = (ctypes.wintypes.HANDLE,)
+
RtlMoveMemory = ctypes.windll.kernel32.RtlMoveMemory
-RtlMoveMemory.argtypes = ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t
-ctypes.windll.kernel32.LocalFree.argtypes = ctypes.wintypes.HLOCAL,
+RtlMoveMemory.argtypes = (ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t)
+
+ctypes.windll.kernel32.LocalFree.argtypes = (ctypes.wintypes.HLOCAL,)
+
+#####################
+# jaraco.windows.mmap
class MemoryMap:
@@ -75,69 +142,128 @@ class MemoryMap:
self.pos = 0
def __enter__(self):
- p_SA = ctypes.byref(self.security_attributes
- ) if self.security_attributes else None
+ p_SA = (
+ ctypes.byref(self.security_attributes)
+ if self.security_attributes
+ else None
+ )
INVALID_HANDLE_VALUE = -1
- PAGE_READWRITE = 4
- FILE_MAP_WRITE = 2
+ PAGE_READWRITE = 0x4
+ FILE_MAP_WRITE = 0x2
filemap = ctypes.windll.kernel32.CreateFileMappingW(
- INVALID_HANDLE_VALUE, p_SA, PAGE_READWRITE, 0, self.length, u(
- self.name))
+ INVALID_HANDLE_VALUE,
+ p_SA,
+ PAGE_READWRITE,
+ 0,
+ self.length,
+ u(self.name),
+ )
handle_nonzero_success(filemap)
if filemap == INVALID_HANDLE_VALUE:
- raise Exception('Failed to create file mapping')
+ raise Exception("Failed to create file mapping")
self.filemap = filemap
self.view = MapViewOfFile(filemap, FILE_MAP_WRITE, 0, 0, 0)
return self
+ def seek(self, pos):
+ self.pos = pos
+
+ def write(self, msg):
+ assert isinstance(msg, bytes)
+ n = len(msg)
+ if self.pos + n >= self.length: # A little safety.
+ raise ValueError(f"Refusing to write {n} bytes")
+ dest = self.view + self.pos
+ length = ctypes.c_size_t(n)
+ ctypes.windll.kernel32.RtlMoveMemory(dest, msg, length)
+ self.pos += n
+
def read(self, n):
"""
Read n bytes from mapped view.
"""
- pass
+ out = ctypes.create_string_buffer(n)
+ source = self.view + self.pos
+ length = ctypes.c_size_t(n)
+ ctypes.windll.kernel32.RtlMoveMemory(out, source, length)
+ self.pos += n
+ return out.raw
def __exit__(self, exc_type, exc_val, tb):
ctypes.windll.kernel32.UnmapViewOfFile(self.view)
ctypes.windll.kernel32.CloseHandle(self.filemap)
-READ_CONTROL = 131072
-STANDARD_RIGHTS_REQUIRED = 983040
+#############################
+# jaraco.windows.api.security
+
+# from WinNT.h
+READ_CONTROL = 0x00020000
+STANDARD_RIGHTS_REQUIRED = 0x000F0000
STANDARD_RIGHTS_READ = READ_CONTROL
STANDARD_RIGHTS_WRITE = READ_CONTROL
STANDARD_RIGHTS_EXECUTE = READ_CONTROL
-STANDARD_RIGHTS_ALL = 2031616
-POLICY_VIEW_LOCAL_INFORMATION = 1
-POLICY_VIEW_AUDIT_INFORMATION = 2
-POLICY_GET_PRIVATE_INFORMATION = 4
-POLICY_TRUST_ADMIN = 8
-POLICY_CREATE_ACCOUNT = 16
-POLICY_CREATE_SECRET = 32
-POLICY_CREATE_PRIVILEGE = 64
-POLICY_SET_DEFAULT_QUOTA_LIMITS = 128
-POLICY_SET_AUDIT_REQUIREMENTS = 256
-POLICY_AUDIT_LOG_ADMIN = 512
-POLICY_SERVER_ADMIN = 1024
-POLICY_LOOKUP_NAMES = 2048
-POLICY_NOTIFICATION = 4096
-POLICY_ALL_ACCESS = (STANDARD_RIGHTS_REQUIRED |
- POLICY_VIEW_LOCAL_INFORMATION | POLICY_VIEW_AUDIT_INFORMATION |
- POLICY_GET_PRIVATE_INFORMATION | POLICY_TRUST_ADMIN |
- POLICY_CREATE_ACCOUNT | POLICY_CREATE_SECRET | POLICY_CREATE_PRIVILEGE |
- POLICY_SET_DEFAULT_QUOTA_LIMITS | POLICY_SET_AUDIT_REQUIREMENTS |
- POLICY_AUDIT_LOG_ADMIN | POLICY_SERVER_ADMIN | POLICY_LOOKUP_NAMES)
-POLICY_READ = (STANDARD_RIGHTS_READ | POLICY_VIEW_AUDIT_INFORMATION |
- POLICY_GET_PRIVATE_INFORMATION)
-POLICY_WRITE = (STANDARD_RIGHTS_WRITE | POLICY_TRUST_ADMIN |
- POLICY_CREATE_ACCOUNT | POLICY_CREATE_SECRET | POLICY_CREATE_PRIVILEGE |
- POLICY_SET_DEFAULT_QUOTA_LIMITS | POLICY_SET_AUDIT_REQUIREMENTS |
- POLICY_AUDIT_LOG_ADMIN | POLICY_SERVER_ADMIN)
-POLICY_EXECUTE = (STANDARD_RIGHTS_EXECUTE | POLICY_VIEW_LOCAL_INFORMATION |
- POLICY_LOOKUP_NAMES)
+STANDARD_RIGHTS_ALL = 0x001F0000
+
+# from NTSecAPI.h
+POLICY_VIEW_LOCAL_INFORMATION = 0x00000001
+POLICY_VIEW_AUDIT_INFORMATION = 0x00000002
+POLICY_GET_PRIVATE_INFORMATION = 0x00000004
+POLICY_TRUST_ADMIN = 0x00000008
+POLICY_CREATE_ACCOUNT = 0x00000010
+POLICY_CREATE_SECRET = 0x00000020
+POLICY_CREATE_PRIVILEGE = 0x00000040
+POLICY_SET_DEFAULT_QUOTA_LIMITS = 0x00000080
+POLICY_SET_AUDIT_REQUIREMENTS = 0x00000100
+POLICY_AUDIT_LOG_ADMIN = 0x00000200
+POLICY_SERVER_ADMIN = 0x00000400
+POLICY_LOOKUP_NAMES = 0x00000800
+POLICY_NOTIFICATION = 0x00001000
+
+POLICY_ALL_ACCESS = (
+ STANDARD_RIGHTS_REQUIRED
+ | POLICY_VIEW_LOCAL_INFORMATION
+ | POLICY_VIEW_AUDIT_INFORMATION
+ | POLICY_GET_PRIVATE_INFORMATION
+ | POLICY_TRUST_ADMIN
+ | POLICY_CREATE_ACCOUNT
+ | POLICY_CREATE_SECRET
+ | POLICY_CREATE_PRIVILEGE
+ | POLICY_SET_DEFAULT_QUOTA_LIMITS
+ | POLICY_SET_AUDIT_REQUIREMENTS
+ | POLICY_AUDIT_LOG_ADMIN
+ | POLICY_SERVER_ADMIN
+ | POLICY_LOOKUP_NAMES
+)
+
+
+POLICY_READ = (
+ STANDARD_RIGHTS_READ
+ | POLICY_VIEW_AUDIT_INFORMATION
+ | POLICY_GET_PRIVATE_INFORMATION
+)
+
+POLICY_WRITE = (
+ STANDARD_RIGHTS_WRITE
+ | POLICY_TRUST_ADMIN
+ | POLICY_CREATE_ACCOUNT
+ | POLICY_CREATE_SECRET
+ | POLICY_CREATE_PRIVILEGE
+ | POLICY_SET_DEFAULT_QUOTA_LIMITS
+ | POLICY_SET_AUDIT_REQUIREMENTS
+ | POLICY_AUDIT_LOG_ADMIN
+ | POLICY_SERVER_ADMIN
+)
+
+POLICY_EXECUTE = (
+ STANDARD_RIGHTS_EXECUTE
+ | POLICY_VIEW_LOCAL_INFORMATION
+ | POLICY_LOOKUP_NAMES
+)
class TokenAccess:
- TOKEN_QUERY = 8
+ TOKEN_QUERY = 0x8
class TokenInformationClass:
@@ -146,8 +272,10 @@ class TokenInformationClass:
class TOKEN_USER(ctypes.Structure):
num = 1
- _fields_ = [('SID', ctypes.c_void_p), ('ATTRIBUTES', ctypes.wintypes.DWORD)
- ]
+ _fields_ = [
+ ("SID", ctypes.c_void_p),
+ ("ATTRIBUTES", ctypes.wintypes.DWORD),
+ ]
class SECURITY_DESCRIPTOR(ctypes.Structure):
@@ -163,12 +291,19 @@ class SECURITY_DESCRIPTOR(ctypes.Structure):
PACL Dacl;
} SECURITY_DESCRIPTOR;
"""
+
SECURITY_DESCRIPTOR_CONTROL = ctypes.wintypes.USHORT
REVISION = 1
- _fields_ = [('Revision', ctypes.c_ubyte), ('Sbz1', ctypes.c_ubyte), (
- 'Control', SECURITY_DESCRIPTOR_CONTROL), ('Owner', ctypes.c_void_p),
- ('Group', ctypes.c_void_p), ('Sacl', ctypes.c_void_p), ('Dacl',
- ctypes.c_void_p)]
+
+ _fields_ = [
+ ("Revision", ctypes.c_ubyte),
+ ("Sbz1", ctypes.c_ubyte),
+ ("Control", SECURITY_DESCRIPTOR_CONTROL),
+ ("Owner", ctypes.c_void_p),
+ ("Group", ctypes.c_void_p),
+ ("Sacl", ctypes.c_void_p),
+ ("Dacl", ctypes.c_void_p),
+ ]
class SECURITY_ATTRIBUTES(ctypes.Structure):
@@ -179,30 +314,77 @@ class SECURITY_ATTRIBUTES(ctypes.Structure):
BOOL bInheritHandle;
} SECURITY_ATTRIBUTES;
"""
- _fields_ = [('nLength', ctypes.wintypes.DWORD), ('lpSecurityDescriptor',
- ctypes.c_void_p), ('bInheritHandle', ctypes.wintypes.BOOL)]
+
+ _fields_ = [
+ ("nLength", ctypes.wintypes.DWORD),
+ ("lpSecurityDescriptor", ctypes.c_void_p),
+ ("bInheritHandle", ctypes.wintypes.BOOL),
+ ]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES)
+ @property
+ def descriptor(self):
+ return self._descriptor
+
+ @descriptor.setter
+ def descriptor(self, value):
+ self._descriptor = value
+ self.lpSecurityDescriptor = ctypes.addressof(value)
+
+
+ctypes.windll.advapi32.SetSecurityDescriptorOwner.argtypes = (
+ ctypes.POINTER(SECURITY_DESCRIPTOR),
+ ctypes.c_void_p,
+ ctypes.wintypes.BOOL,
+)
-ctypes.windll.advapi32.SetSecurityDescriptorOwner.argtypes = ctypes.POINTER(
- SECURITY_DESCRIPTOR), ctypes.c_void_p, ctypes.wintypes.BOOL
+#########################
+# jaraco.windows.security
def GetTokenInformation(token, information_class):
"""
Given a token, get the token information for it.
"""
- pass
+ data_size = ctypes.wintypes.DWORD()
+ ctypes.windll.advapi32.GetTokenInformation(
+ token, information_class.num, 0, 0, ctypes.byref(data_size)
+ )
+ data = ctypes.create_string_buffer(data_size.value)
+ handle_nonzero_success(
+ ctypes.windll.advapi32.GetTokenInformation(
+ token,
+ information_class.num,
+ ctypes.byref(data),
+ ctypes.sizeof(data),
+ ctypes.byref(data_size),
+ )
+ )
+ return ctypes.cast(data, ctypes.POINTER(TOKEN_USER)).contents
+
+
+def OpenProcessToken(proc_handle, access):
+ result = ctypes.wintypes.HANDLE()
+ proc_handle = ctypes.wintypes.HANDLE(proc_handle)
+ handle_nonzero_success(
+ ctypes.windll.advapi32.OpenProcessToken(
+ proc_handle, access, ctypes.byref(result)
+ )
+ )
+ return result
def get_current_user():
"""
Return a TOKEN_USER for the owner of this process.
"""
- pass
+ process = OpenProcessToken(
+ ctypes.windll.kernel32.GetCurrentProcess(), TokenAccess.TOKEN_QUERY
+ )
+ return GetTokenInformation(process, TOKEN_USER)
def get_security_attributes_for_user(user=None):
@@ -210,4 +392,22 @@ def get_security_attributes_for_user(user=None):
Return a SECURITY_ATTRIBUTES structure with the SID set to the
specified user (uses current user if none is specified).
"""
- pass
+ if user is None:
+ user = get_current_user()
+
+ assert isinstance(user, TOKEN_USER), "user must be TOKEN_USER instance"
+
+ SD = SECURITY_DESCRIPTOR()
+ SA = SECURITY_ATTRIBUTES()
+ # by attaching the actual security descriptor, it will be garbage-
+ # collected with the security attributes
+ SA.descriptor = SD
+ SA.bInheritHandle = 1
+
+ ctypes.windll.advapi32.InitializeSecurityDescriptor(
+ ctypes.byref(SD), SECURITY_DESCRIPTOR.REVISION
+ )
+ ctypes.windll.advapi32.SetSecurityDescriptorOwner(
+ ctypes.byref(SD), user.SID, 0
+ )
+ return SA
diff --git a/paramiko/agent.py b/paramiko/agent.py
index 440c59d9..b29a0d14 100644
--- a/paramiko/agent.py
+++ b/paramiko/agent.py
@@ -1,6 +1,25 @@
+# Copyright (C) 2003-2007 John Rochester <john@jrochester.org>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
SSH Agent interface
"""
+
import os
import socket
import struct
@@ -12,24 +31,34 @@ import stat
from logging import DEBUG
from select import select
from paramiko.common import io_sleep, byte_chr
+
from paramiko.ssh_exception import SSHException, AuthenticationException
from paramiko.message import Message
from paramiko.pkey import PKey, UnknownKeyType
from paramiko.util import asbytes, get_logger
+
cSSH2_AGENTC_REQUEST_IDENTITIES = byte_chr(11)
SSH2_AGENT_IDENTITIES_ANSWER = 12
cSSH2_AGENTC_SIGN_REQUEST = byte_chr(13)
SSH2_AGENT_SIGN_RESPONSE = 14
+
SSH_AGENT_RSA_SHA2_256 = 2
SSH_AGENT_RSA_SHA2_512 = 4
-ALGORITHM_FLAG_MAP = {'rsa-sha2-256': SSH_AGENT_RSA_SHA2_256,
- 'rsa-sha2-512': SSH_AGENT_RSA_SHA2_512}
+# NOTE: RFC mildly confusing; while these flags are OR'd together, OpenSSH at
+# least really treats them like "AND"s, in the sense that if it finds the
+# SHA256 flag set it won't continue looking at the SHA512 one; it
+# short-circuits right away.
+# Thus, we never want to eg submit 6 to say "either's good".
+ALGORITHM_FLAG_MAP = {
+ "rsa-sha2-256": SSH_AGENT_RSA_SHA2_256,
+ "rsa-sha2-512": SSH_AGENT_RSA_SHA2_512,
+}
for key, value in list(ALGORITHM_FLAG_MAP.items()):
- ALGORITHM_FLAG_MAP[f'{key}-cert-v01@openssh.com'] = value
+ ALGORITHM_FLAG_MAP[f"{key}-cert-v01@openssh.com"] = value
+# TODO 4.0: rename all these - including making some of their methods public?
class AgentSSH:
-
def __init__(self):
self._conn = None
self._keys = ()
@@ -47,7 +76,47 @@ class AgentSSH:
a tuple of `.AgentKey` objects representing keys available on the
SSH agent
"""
- pass
+ return self._keys
+
+ def _connect(self, conn):
+ self._conn = conn
+ ptype, result = self._send_message(cSSH2_AGENTC_REQUEST_IDENTITIES)
+ if ptype != SSH2_AGENT_IDENTITIES_ANSWER:
+ raise SSHException("could not get keys from ssh-agent")
+ keys = []
+ for i in range(result.get_int()):
+ keys.append(
+ AgentKey(
+ agent=self,
+ blob=result.get_binary(),
+ comment=result.get_text(),
+ )
+ )
+ self._keys = tuple(keys)
+
+ def _close(self):
+ if self._conn is not None:
+ self._conn.close()
+ self._conn = None
+ self._keys = ()
+
+ def _send_message(self, msg):
+ msg = asbytes(msg)
+ self._conn.send(struct.pack(">I", len(msg)) + msg)
+ data = self._read_all(4)
+ msg = Message(self._read_all(struct.unpack(">I", data)[0]))
+ return ord(msg.get_byte()), msg
+
+ def _read_all(self, wanted):
+ result = self._conn.recv(wanted)
+ while len(result) < wanted:
+ if len(result) == 0:
+ raise SSHException("lost ssh-agent")
+ extra = self._conn.recv(wanted - len(result))
+ if len(extra) == 0:
+ raise SSHException("lost ssh-agent")
+ result += extra
+ return result
class AgentProxyThread(threading.Thread):
@@ -60,6 +129,54 @@ class AgentProxyThread(threading.Thread):
self._agent = agent
self._exit = False
+ def run(self):
+ try:
+ (r, addr) = self.get_connection()
+ # Found that r should be either
+ # a socket from the socket library or None
+ self.__inr = r
+ # The address should be an IP address as a string? or None
+ self.__addr = addr
+ self._agent.connect()
+ if not isinstance(self._agent, int) and (
+ self._agent._conn is None
+ or not hasattr(self._agent._conn, "fileno")
+ ):
+ raise AuthenticationException("Unable to connect to SSH agent")
+ self._communicate()
+ except:
+ # XXX Not sure what to do here ... raise or pass ?
+ raise
+
+ def _communicate(self):
+ import fcntl
+
+ oldflags = fcntl.fcntl(self.__inr, fcntl.F_GETFL)
+ fcntl.fcntl(self.__inr, fcntl.F_SETFL, oldflags | os.O_NONBLOCK)
+ while not self._exit:
+ events = select([self._agent._conn, self.__inr], [], [], 0.5)
+ for fd in events[0]:
+ if self._agent._conn == fd:
+ data = self._agent._conn.recv(512)
+ if len(data) != 0:
+ self.__inr.send(data)
+ else:
+ self._close()
+ break
+ elif self.__inr == fd:
+ data = self.__inr.recv(512)
+ if len(data) != 0:
+ self._agent._conn.send(data)
+ else:
+ self._close()
+ break
+ time.sleep(io_sleep)
+
+ def _close(self):
+ self._exit = True
+ self.__inr.close()
+ self._agent._conn.close()
+
class AgentLocalProxy(AgentProxyThread):
"""
@@ -76,7 +193,14 @@ class AgentLocalProxy(AgentProxyThread):
May block!
"""
- pass
+ conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ try:
+ conn.bind(self._agent._get_filename())
+ conn.listen(1)
+ (r, addr) = conn.accept()
+ return r, addr
+ except:
+ raise
class AgentRemoteProxy(AgentProxyThread):
@@ -88,6 +212,9 @@ class AgentRemoteProxy(AgentProxyThread):
AgentProxyThread.__init__(self, agent)
self.__chan = chan
+ def get_connection(self):
+ return self.__chan, None
+
def get_agent_connection():
"""
@@ -95,7 +222,26 @@ def get_agent_connection():
.. versionadded:: 2.10
"""
- pass
+ if ("SSH_AUTH_SOCK" in os.environ) and (sys.platform != "win32"):
+ conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ try:
+ conn.connect(os.environ["SSH_AUTH_SOCK"])
+ return conn
+ except:
+ # probably a dangling env var: the ssh agent is gone
+ return
+ elif sys.platform == "win32":
+ from . import win_pageant, win_openssh
+
+ conn = None
+ if win_pageant.can_talk_to_agent():
+ conn = win_pageant.PageantConnection()
+ elif win_openssh.can_talk_to_agent():
+ conn = win_openssh.OpenSSHAgentConnection()
+ return conn
+ else:
+ # no agent support
+ return
class AgentClientProxy:
@@ -124,14 +270,21 @@ class AgentClientProxy:
"""
Method automatically called by ``AgentProxyThread.run``.
"""
- pass
+ conn = get_agent_connection()
+ if not conn:
+ return
+ self._conn = conn
def close(self):
"""
Close the current connection and terminate the agent
Should be called manually
"""
- pass
+ if hasattr(self, "thread"):
+ self.thread._exit = True
+ self.thread.join(1000)
+ if self._conn is not None:
+ self._conn.close()
class AgentServerProxy(AgentSSH):
@@ -155,21 +308,32 @@ class AgentServerProxy(AgentSSH):
def __init__(self, t):
AgentSSH.__init__(self)
self.__t = t
- self._dir = tempfile.mkdtemp('sshproxy')
+ self._dir = tempfile.mkdtemp("sshproxy")
os.chmod(self._dir, stat.S_IRWXU)
- self._file = self._dir + '/sshproxy.ssh'
+ self._file = self._dir + "/sshproxy.ssh"
self.thread = AgentLocalProxy(self)
self.thread.start()
def __del__(self):
self.close()
+ def connect(self):
+ conn_sock = self.__t.open_forward_agent_channel()
+ if conn_sock is None:
+ raise SSHException("lost ssh-agent")
+ conn_sock.set_name("auth-agent")
+ self._connect(conn_sock)
+
def close(self):
"""
Terminate the agent, clean the files, close connections
Should be called manually
"""
- pass
+ os.remove(self._file)
+ os.rmdir(self._dir)
+ self.thread._exit = True
+ self.thread.join(1000)
+ self._close()
def get_env(self):
"""
@@ -178,7 +342,10 @@ class AgentServerProxy(AgentSSH):
:return:
a dict containing the ``SSH_AUTH_SOCK`` environment variables
"""
- pass
+ return {"SSH_AUTH_SOCK": self._get_filename()}
+
+ def _get_filename(self):
+ return self._file
class AgentRequestHandler:
@@ -209,9 +376,16 @@ class AgentRequestHandler:
chanClient.request_forward_agent(self._forward_agent_handler)
self.__clientProxys = []
+ def _forward_agent_handler(self, chanRemote):
+ self.__clientProxys.append(AgentClientProxy(chanRemote))
+
def __del__(self):
self.close()
+ def close(self):
+ for p in self.__clientProxys:
+ p.close()
+
class Agent(AgentSSH):
"""
@@ -234,6 +408,7 @@ class Agent(AgentSSH):
def __init__(self):
AgentSSH.__init__(self)
+
conn = get_agent_connection()
if not conn:
return
@@ -243,7 +418,7 @@ class Agent(AgentSSH):
"""
Close the SSH agent connection.
"""
- pass
+ self._close()
class AgentKey(PKey):
@@ -260,7 +435,7 @@ class AgentKey(PKey):
key instance this key is a proxy for, if one was obtainable, else None.
"""
- def __init__(self, agent, blob, comment=''):
+ def __init__(self, agent, blob, comment=""):
self.agent = agent
self.blob = blob
self.comment = comment
@@ -269,16 +444,54 @@ class AgentKey(PKey):
self._logger = get_logger(__file__)
self.inner_key = None
try:
- self.inner_key = PKey.from_type_string(key_type=self.name,
- key_bytes=blob)
+ self.inner_key = PKey.from_type_string(
+ key_type=self.name, key_bytes=blob
+ )
except UnknownKeyType:
- err = 'Unable to derive inner_key for agent key of type {!r}'
+ # Log, but don't explode, since inner_key is a best-effort thing.
+ err = "Unable to derive inner_key for agent key of type {!r}"
self.log(DEBUG, err.format(self.name))
+ def log(self, *args, **kwargs):
+ return self._logger.log(*args, **kwargs)
+
+ def asbytes(self):
+ # Prefer inner_key.asbytes, since that will differ for eg RSA-CERT
+ return self.inner_key.asbytes() if self.inner_key else self.blob
+
+ def get_name(self):
+ return self.name
+
+ def get_bits(self):
+ # Have to work around PKey's default get_bits being crap
+ if self.inner_key is not None:
+ return self.inner_key.get_bits()
+ return super().get_bits()
+
def __getattr__(self, name):
"""
Proxy any un-implemented methods/properties to the inner_key.
"""
- if self.inner_key is None:
+ if self.inner_key is None: # nothing to proxy to
raise AttributeError(name)
return getattr(self.inner_key, name)
+
+ @property
+ def _fields(self):
+ fallback = [self.get_name(), self.blob]
+ return self.inner_key._fields if self.inner_key else fallback
+
+ def sign_ssh_data(self, data, algorithm=None):
+ msg = Message()
+ msg.add_byte(cSSH2_AGENTC_SIGN_REQUEST)
+ # NOTE: this used to be just self.blob, which is not entirely right for
+ # RSA-CERT 'keys' - those end up always degrading to ssh-rsa type
+ # signatures, for reasons probably internal to OpenSSH's agent code,
+ # even if everything else wants SHA2 (including our flag map).
+ msg.add_string(self.asbytes())
+ msg.add_string(data)
+ msg.add_int(ALGORITHM_FLAG_MAP.get(algorithm, 0))
+ ptype, result = self.agent._send_message(msg)
+ if ptype != SSH2_AGENT_SIGN_RESPONSE:
+ raise SSHException("key cannot be used for signing")
+ return result.get_binary()
diff --git a/paramiko/auth_handler.py b/paramiko/auth_handler.py
index 679ae516..bc7f298f 100644
--- a/paramiko/auth_handler.py
+++ b/paramiko/auth_handler.py
@@ -1,14 +1,75 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
`.AuthHandler`
"""
+
import weakref
import threading
import time
import re
-from paramiko.common import cMSG_SERVICE_REQUEST, cMSG_DISCONNECT, DISCONNECT_SERVICE_NOT_AVAILABLE, DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, cMSG_USERAUTH_REQUEST, cMSG_SERVICE_ACCEPT, DEBUG, AUTH_SUCCESSFUL, INFO, cMSG_USERAUTH_SUCCESS, cMSG_USERAUTH_FAILURE, AUTH_PARTIALLY_SUCCESSFUL, cMSG_USERAUTH_INFO_REQUEST, WARNING, AUTH_FAILED, cMSG_USERAUTH_PK_OK, cMSG_USERAUTH_INFO_RESPONSE, MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT, MSG_USERAUTH_REQUEST, MSG_USERAUTH_SUCCESS, MSG_USERAUTH_FAILURE, MSG_USERAUTH_BANNER, MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE, cMSG_USERAUTH_GSSAPI_RESPONSE, cMSG_USERAUTH_GSSAPI_TOKEN, cMSG_USERAUTH_GSSAPI_MIC, MSG_USERAUTH_GSSAPI_RESPONSE, MSG_USERAUTH_GSSAPI_TOKEN, MSG_USERAUTH_GSSAPI_ERROR, MSG_USERAUTH_GSSAPI_ERRTOK, MSG_USERAUTH_GSSAPI_MIC, MSG_NAMES, cMSG_USERAUTH_BANNER
+
+from paramiko.common import (
+ cMSG_SERVICE_REQUEST,
+ cMSG_DISCONNECT,
+ DISCONNECT_SERVICE_NOT_AVAILABLE,
+ DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
+ cMSG_USERAUTH_REQUEST,
+ cMSG_SERVICE_ACCEPT,
+ DEBUG,
+ AUTH_SUCCESSFUL,
+ INFO,
+ cMSG_USERAUTH_SUCCESS,
+ cMSG_USERAUTH_FAILURE,
+ AUTH_PARTIALLY_SUCCESSFUL,
+ cMSG_USERAUTH_INFO_REQUEST,
+ WARNING,
+ AUTH_FAILED,
+ cMSG_USERAUTH_PK_OK,
+ cMSG_USERAUTH_INFO_RESPONSE,
+ MSG_SERVICE_REQUEST,
+ MSG_SERVICE_ACCEPT,
+ MSG_USERAUTH_REQUEST,
+ MSG_USERAUTH_SUCCESS,
+ MSG_USERAUTH_FAILURE,
+ MSG_USERAUTH_BANNER,
+ MSG_USERAUTH_INFO_REQUEST,
+ MSG_USERAUTH_INFO_RESPONSE,
+ cMSG_USERAUTH_GSSAPI_RESPONSE,
+ cMSG_USERAUTH_GSSAPI_TOKEN,
+ cMSG_USERAUTH_GSSAPI_MIC,
+ MSG_USERAUTH_GSSAPI_RESPONSE,
+ MSG_USERAUTH_GSSAPI_TOKEN,
+ MSG_USERAUTH_GSSAPI_ERROR,
+ MSG_USERAUTH_GSSAPI_ERRTOK,
+ MSG_USERAUTH_GSSAPI_MIC,
+ MSG_NAMES,
+ cMSG_USERAUTH_BANNER,
+)
from paramiko.message import Message
from paramiko.util import b, u
-from paramiko.ssh_exception import SSHException, AuthenticationException, BadAuthenticationType, PartialAuthentication
+from paramiko.ssh_exception import (
+ SSHException,
+ AuthenticationException,
+ BadAuthenticationType,
+ PartialAuthentication,
+)
from paramiko.server import InteractiveQuery
from paramiko.ssh_gss import GSSAuth, GSS_EXCEPTIONS
@@ -23,22 +84,129 @@ class AuthHandler:
self.username = None
self.authenticated = False
self.auth_event = None
- self.auth_method = ''
+ self.auth_method = ""
self.banner = None
self.password = None
self.private_key = None
self.interactive_handler = None
self.submethods = None
+ # for server mode:
self.auth_username = None
self.auth_fail_count = 0
+ # for GSSAPI
self.gss_host = None
self.gss_deleg_creds = True
- def auth_interactive(self, username, handler, event, submethods=''):
+ def _log(self, *args):
+ return self.transport._log(*args)
+
+ def is_authenticated(self):
+ return self.authenticated
+
+ def get_username(self):
+ if self.transport.server_mode:
+ return self.auth_username
+ else:
+ return self.username
+
+ def auth_none(self, username, event):
+ self.transport.lock.acquire()
+ try:
+ self.auth_event = event
+ self.auth_method = "none"
+ self.username = username
+ self._request_auth()
+ finally:
+ self.transport.lock.release()
+
+ def auth_publickey(self, username, key, event):
+ self.transport.lock.acquire()
+ try:
+ self.auth_event = event
+ self.auth_method = "publickey"
+ self.username = username
+ self.private_key = key
+ self._request_auth()
+ finally:
+ self.transport.lock.release()
+
+ def auth_password(self, username, password, event):
+ self.transport.lock.acquire()
+ try:
+ self.auth_event = event
+ self.auth_method = "password"
+ self.username = username
+ self.password = password
+ self._request_auth()
+ finally:
+ self.transport.lock.release()
+
+ def auth_interactive(self, username, handler, event, submethods=""):
"""
response_list = handler(title, instructions, prompt_list)
"""
- pass
+ self.transport.lock.acquire()
+ try:
+ self.auth_event = event
+ self.auth_method = "keyboard-interactive"
+ self.username = username
+ self.interactive_handler = handler
+ self.submethods = submethods
+ self._request_auth()
+ finally:
+ self.transport.lock.release()
+
+ def auth_gssapi_with_mic(self, username, gss_host, gss_deleg_creds, event):
+ self.transport.lock.acquire()
+ try:
+ self.auth_event = event
+ self.auth_method = "gssapi-with-mic"
+ self.username = username
+ self.gss_host = gss_host
+ self.gss_deleg_creds = gss_deleg_creds
+ self._request_auth()
+ finally:
+ self.transport.lock.release()
+
+ def auth_gssapi_keyex(self, username, event):
+ self.transport.lock.acquire()
+ try:
+ self.auth_event = event
+ self.auth_method = "gssapi-keyex"
+ self.username = username
+ self._request_auth()
+ finally:
+ self.transport.lock.release()
+
+ def abort(self):
+ if self.auth_event is not None:
+ self.auth_event.set()
+
+ # ...internals...
+
+ def _request_auth(self):
+ m = Message()
+ m.add_byte(cMSG_SERVICE_REQUEST)
+ m.add_string("ssh-userauth")
+ self.transport._send_message(m)
+
+ def _disconnect_service_not_available(self):
+ m = Message()
+ m.add_byte(cMSG_DISCONNECT)
+ m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE)
+ m.add_string("Service not available")
+ m.add_string("en")
+ self.transport._send_message(m)
+ self.transport.close()
+
+ def _disconnect_no_more_auth(self):
+ m = Message()
+ m.add_byte(cMSG_DISCONNECT)
+ m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE)
+ m.add_string("No more auth methods available")
+ m.add_string("en")
+ self.transport._send_message(m)
+ self.transport.close()
def _get_key_type_and_bits(self, key):
"""
@@ -46,7 +214,647 @@ class AuthHandler:
Intended for input to or verification of, key signatures.
"""
- pass
+ # Use certificate contents, if available, plain pubkey otherwise
+ if key.public_blob:
+ return key.public_blob.key_type, key.public_blob.key_blob
+ else:
+ return key.get_name(), key
+
+ def _get_session_blob(self, key, service, username, algorithm):
+ m = Message()
+ m.add_string(self.transport.session_id)
+ m.add_byte(cMSG_USERAUTH_REQUEST)
+ m.add_string(username)
+ m.add_string(service)
+ m.add_string("publickey")
+ m.add_boolean(True)
+ _, bits = self._get_key_type_and_bits(key)
+ m.add_string(algorithm)
+ m.add_string(bits)
+ return m.asbytes()
+
+ def wait_for_response(self, event):
+ max_ts = None
+ if self.transport.auth_timeout is not None:
+ max_ts = time.time() + self.transport.auth_timeout
+ while True:
+ event.wait(0.1)
+ if not self.transport.is_active():
+ e = self.transport.get_exception()
+ if (e is None) or issubclass(e.__class__, EOFError):
+ e = AuthenticationException(
+ "Authentication failed: transport shut down or saw EOF"
+ )
+ raise e
+ if event.is_set():
+ break
+ if max_ts is not None and max_ts <= time.time():
+ raise AuthenticationException("Authentication timeout.")
+
+ if not self.is_authenticated():
+ e = self.transport.get_exception()
+ if e is None:
+ e = AuthenticationException("Authentication failed.")
+ # this is horrible. Python Exception isn't yet descended from
+ # object, so type(e) won't work. :(
+ # TODO 4.0: lol. just lmao.
+ if issubclass(e.__class__, PartialAuthentication):
+ return e.allowed_types
+ raise e
+ return []
+
+ def _parse_service_request(self, m):
+ service = m.get_text()
+ if self.transport.server_mode and (service == "ssh-userauth"):
+ # accepted
+ m = Message()
+ m.add_byte(cMSG_SERVICE_ACCEPT)
+ m.add_string(service)
+ self.transport._send_message(m)
+ banner, language = self.transport.server_object.get_banner()
+ if banner:
+ m = Message()
+ m.add_byte(cMSG_USERAUTH_BANNER)
+ m.add_string(banner)
+ m.add_string(language)
+ self.transport._send_message(m)
+ return
+ # dunno this one
+ self._disconnect_service_not_available()
+
+ def _generate_key_from_request(self, algorithm, keyblob):
+ # For use in server mode.
+ options = self.transport.preferred_pubkeys
+ if algorithm.replace("-cert-v01@openssh.com", "") not in options:
+ err = (
+ "Auth rejected: pubkey algorithm '{}' unsupported or disabled"
+ )
+ self._log(INFO, err.format(algorithm))
+ return None
+ return self.transport._key_info[algorithm](Message(keyblob))
+
+ def _choose_fallback_pubkey_algorithm(self, key_type, my_algos):
+ # Fallback: first one in our (possibly tweaked by caller) list
+ pubkey_algo = my_algos[0]
+ msg = "Server did not send a server-sig-algs list; defaulting to our first preferred algo ({!r})" # noqa
+ self._log(DEBUG, msg.format(pubkey_algo))
+ self._log(
+ DEBUG,
+ "NOTE: you may use the 'disabled_algorithms' SSHClient/Transport init kwarg to disable that or other algorithms if your server does not support them!", # noqa
+ )
+ return pubkey_algo
+
+ def _finalize_pubkey_algorithm(self, key_type):
+ # Short-circuit for non-RSA keys
+ if "rsa" not in key_type:
+ return key_type
+ self._log(
+ DEBUG,
+ "Finalizing pubkey algorithm for key of type {!r}".format(
+ key_type
+ ),
+ )
+ # NOTE re #2017: When the key is an RSA cert and the remote server is
+ # OpenSSH 7.7 or earlier, always use ssh-rsa-cert-v01@openssh.com.
+ # Those versions of the server won't support rsa-sha2 family sig algos
+ # for certs specifically, and in tandem with various server bugs
+ # regarding server-sig-algs, it's impossible to fit this into the rest
+ # of the logic here.
+ if key_type.endswith("-cert-v01@openssh.com") and re.search(
+ r"-OpenSSH_(?:[1-6]|7\.[0-7])", self.transport.remote_version
+ ):
+ pubkey_algo = "ssh-rsa-cert-v01@openssh.com"
+ self.transport._agreed_pubkey_algorithm = pubkey_algo
+ self._log(DEBUG, "OpenSSH<7.8 + RSA cert = forcing ssh-rsa!")
+ self._log(
+ DEBUG, "Agreed upon {!r} pubkey algorithm".format(pubkey_algo)
+ )
+ return pubkey_algo
+ # Normal attempts to handshake follow from here.
+ # Only consider RSA algos from our list, lest we agree on another!
+ my_algos = [x for x in self.transport.preferred_pubkeys if "rsa" in x]
+ self._log(DEBUG, "Our pubkey algorithm list: {}".format(my_algos))
+ # Short-circuit negatively if user disabled all RSA algos (heh)
+ if not my_algos:
+ raise SSHException(
+ "An RSA key was specified, but no RSA pubkey algorithms are configured!" # noqa
+ )
+ # Check for server-sig-algs if supported & sent
+ server_algo_str = u(
+ self.transport.server_extensions.get("server-sig-algs", b(""))
+ )
+ pubkey_algo = None
+ # Prefer to match against server-sig-algs
+ if server_algo_str:
+ server_algos = server_algo_str.split(",")
+ self._log(
+ DEBUG, "Server-side algorithm list: {}".format(server_algos)
+ )
+ # Only use algos from our list that the server likes, in our own
+ # preference order. (NOTE: purposefully using same style as in
+ # Transport...expect to refactor later)
+ agreement = list(filter(server_algos.__contains__, my_algos))
+ if agreement:
+ pubkey_algo = agreement[0]
+ self._log(
+ DEBUG,
+ "Agreed upon {!r} pubkey algorithm".format(pubkey_algo),
+ )
+ else:
+ self._log(DEBUG, "No common pubkey algorithms exist! Dying.")
+ # TODO: MAY want to use IncompatiblePeer again here but that's
+ # technically for initial key exchange, not pubkey auth.
+ err = "Unable to agree on a pubkey algorithm for signing a {!r} key!" # noqa
+ raise AuthenticationException(err.format(key_type))
+ # Fallback to something based purely on the key & our configuration
+ else:
+ pubkey_algo = self._choose_fallback_pubkey_algorithm(
+ key_type, my_algos
+ )
+ if key_type.endswith("-cert-v01@openssh.com"):
+ pubkey_algo += "-cert-v01@openssh.com"
+ self.transport._agreed_pubkey_algorithm = pubkey_algo
+ return pubkey_algo
+
+ def _parse_service_accept(self, m):
+ service = m.get_text()
+ if service == "ssh-userauth":
+ self._log(DEBUG, "userauth is OK")
+ m = Message()
+ m.add_byte(cMSG_USERAUTH_REQUEST)
+ m.add_string(self.username)
+ m.add_string("ssh-connection")
+ m.add_string(self.auth_method)
+ if self.auth_method == "password":
+ m.add_boolean(False)
+ password = b(self.password)
+ m.add_string(password)
+ elif self.auth_method == "publickey":
+ m.add_boolean(True)
+ key_type, bits = self._get_key_type_and_bits(self.private_key)
+ algorithm = self._finalize_pubkey_algorithm(key_type)
+ m.add_string(algorithm)
+ m.add_string(bits)
+ blob = self._get_session_blob(
+ self.private_key,
+ "ssh-connection",
+ self.username,
+ algorithm,
+ )
+ sig = self.private_key.sign_ssh_data(blob, algorithm)
+ m.add_string(sig)
+ elif self.auth_method == "keyboard-interactive":
+ m.add_string("")
+ m.add_string(self.submethods)
+ elif self.auth_method == "gssapi-with-mic":
+ sshgss = GSSAuth(self.auth_method, self.gss_deleg_creds)
+ m.add_bytes(sshgss.ssh_gss_oids())
+ # send the supported GSSAPI OIDs to the server
+ self.transport._send_message(m)
+ ptype, m = self.transport.packetizer.read_message()
+ if ptype == MSG_USERAUTH_BANNER:
+ self._parse_userauth_banner(m)
+ ptype, m = self.transport.packetizer.read_message()
+ if ptype == MSG_USERAUTH_GSSAPI_RESPONSE:
+ # Read the mechanism selected by the server. We send just
+ # the Kerberos V5 OID, so the server can only respond with
+ # this OID.
+ mech = m.get_string()
+ m = Message()
+ m.add_byte(cMSG_USERAUTH_GSSAPI_TOKEN)
+ try:
+ m.add_string(
+ sshgss.ssh_init_sec_context(
+ self.gss_host, mech, self.username
+ )
+ )
+ except GSS_EXCEPTIONS as e:
+ return self._handle_local_gss_failure(e)
+ self.transport._send_message(m)
+ while True:
+ ptype, m = self.transport.packetizer.read_message()
+ if ptype == MSG_USERAUTH_GSSAPI_TOKEN:
+ srv_token = m.get_string()
+ try:
+ next_token = sshgss.ssh_init_sec_context(
+ self.gss_host,
+ mech,
+ self.username,
+ srv_token,
+ )
+ except GSS_EXCEPTIONS as e:
+ return self._handle_local_gss_failure(e)
+ # After this step the GSSAPI should not return any
+ # token. If it does, we keep sending the token to
+ # the server until no more token is returned.
+ if next_token is None:
+ break
+ else:
+ m = Message()
+ m.add_byte(cMSG_USERAUTH_GSSAPI_TOKEN)
+ m.add_string(next_token)
+ self.transport.send_message(m)
+ else:
+ raise SSHException(
+ "Received Package: {}".format(MSG_NAMES[ptype])
+ )
+ m = Message()
+ m.add_byte(cMSG_USERAUTH_GSSAPI_MIC)
+ # send the MIC to the server
+ m.add_string(sshgss.ssh_get_mic(self.transport.session_id))
+ elif ptype == MSG_USERAUTH_GSSAPI_ERRTOK:
+ # RFC 4462 says we are not required to implement GSS-API
+ # error messages.
+ # See RFC 4462 Section 3.8 in
+ # http://www.ietf.org/rfc/rfc4462.txt
+ raise SSHException("Server returned an error token")
+ elif ptype == MSG_USERAUTH_GSSAPI_ERROR:
+ maj_status = m.get_int()
+ min_status = m.get_int()
+ err_msg = m.get_string()
+ m.get_string() # Lang tag - discarded
+ raise SSHException(
+ """GSS-API Error:
+Major Status: {}
+Minor Status: {}
+Error Message: {}
+""".format(
+ maj_status, min_status, err_msg
+ )
+ )
+ elif ptype == MSG_USERAUTH_FAILURE:
+ self._parse_userauth_failure(m)
+ return
+ else:
+ raise SSHException(
+ "Received Package: {}".format(MSG_NAMES[ptype])
+ )
+ elif (
+ self.auth_method == "gssapi-keyex"
+ and self.transport.gss_kex_used
+ ):
+ kexgss = self.transport.kexgss_ctxt
+ kexgss.set_username(self.username)
+ mic_token = kexgss.ssh_get_mic(self.transport.session_id)
+ m.add_string(mic_token)
+ elif self.auth_method == "none":
+ pass
+ else:
+ raise SSHException(
+ 'Unknown auth method "{}"'.format(self.auth_method)
+ )
+ self.transport._send_message(m)
+ else:
+ self._log(
+ DEBUG, 'Service request "{}" accepted (?)'.format(service)
+ )
+
+ def _send_auth_result(self, username, method, result):
+ # okay, send result
+ m = Message()
+ if result == AUTH_SUCCESSFUL:
+ self._log(INFO, "Auth granted ({}).".format(method))
+ m.add_byte(cMSG_USERAUTH_SUCCESS)
+ self.authenticated = True
+ else:
+ self._log(INFO, "Auth rejected ({}).".format(method))
+ m.add_byte(cMSG_USERAUTH_FAILURE)
+ m.add_string(
+ self.transport.server_object.get_allowed_auths(username)
+ )
+ if result == AUTH_PARTIALLY_SUCCESSFUL:
+ m.add_boolean(True)
+ else:
+ m.add_boolean(False)
+ self.auth_fail_count += 1
+ self.transport._send_message(m)
+ if self.auth_fail_count >= 10:
+ self._disconnect_no_more_auth()
+ if result == AUTH_SUCCESSFUL:
+ self.transport._auth_trigger()
+
+ def _interactive_query(self, q):
+ # make interactive query instead of response
+ m = Message()
+ m.add_byte(cMSG_USERAUTH_INFO_REQUEST)
+ m.add_string(q.name)
+ m.add_string(q.instructions)
+ m.add_string(bytes())
+ m.add_int(len(q.prompts))
+ for p in q.prompts:
+ m.add_string(p[0])
+ m.add_boolean(p[1])
+ self.transport._send_message(m)
+
+ def _parse_userauth_request(self, m):
+ if not self.transport.server_mode:
+ # er, uh... what?
+ m = Message()
+ m.add_byte(cMSG_USERAUTH_FAILURE)
+ m.add_string("none")
+ m.add_boolean(False)
+ self.transport._send_message(m)
+ return
+ if self.authenticated:
+ # ignore
+ return
+ username = m.get_text()
+ service = m.get_text()
+ method = m.get_text()
+ self._log(
+ DEBUG,
+ "Auth request (type={}) service={}, username={}".format(
+ method, service, username
+ ),
+ )
+ if service != "ssh-connection":
+ self._disconnect_service_not_available()
+ return
+ if (self.auth_username is not None) and (
+ self.auth_username != username
+ ):
+ self._log(
+ WARNING,
+ "Auth rejected because the client attempted to change username in mid-flight", # noqa
+ )
+ self._disconnect_no_more_auth()
+ return
+ self.auth_username = username
+ # check if GSS-API authentication is enabled
+ gss_auth = self.transport.server_object.enable_auth_gssapi()
+
+ if method == "none":
+ result = self.transport.server_object.check_auth_none(username)
+ elif method == "password":
+ changereq = m.get_boolean()
+ password = m.get_binary()
+ try:
+ password = password.decode("UTF-8")
+ except UnicodeError:
+ # some clients/servers expect non-utf-8 passwords!
+ # in this case, just return the raw byte string.
+ pass
+ if changereq:
+ # always treated as failure, since we don't support changing
+ # passwords, but collect the list of valid auth types from
+ # the callback anyway
+ self._log(DEBUG, "Auth request to change passwords (rejected)")
+ newpassword = m.get_binary()
+ try:
+ newpassword = newpassword.decode("UTF-8", "replace")
+ except UnicodeError:
+ pass
+ result = AUTH_FAILED
+ else:
+ result = self.transport.server_object.check_auth_password(
+ username, password
+ )
+ elif method == "publickey":
+ sig_attached = m.get_boolean()
+ # NOTE: server never wants to guess a client's algo, they're
+ # telling us directly. No need for _finalize_pubkey_algorithm
+ # anywhere in this flow.
+ algorithm = m.get_text()
+ keyblob = m.get_binary()
+ try:
+ key = self._generate_key_from_request(algorithm, keyblob)
+ except SSHException as e:
+ self._log(INFO, "Auth rejected: public key: {}".format(str(e)))
+ key = None
+ except Exception as e:
+ msg = "Auth rejected: unsupported or mangled public key ({}: {})" # noqa
+ self._log(INFO, msg.format(e.__class__.__name__, e))
+ key = None
+ if key is None:
+ self._disconnect_no_more_auth()
+ return
+ # first check if this key is okay... if not, we can skip the verify
+ result = self.transport.server_object.check_auth_publickey(
+ username, key
+ )
+ if result != AUTH_FAILED:
+ # key is okay, verify it
+ if not sig_attached:
+ # client wants to know if this key is acceptable, before it
+ # signs anything... send special "ok" message
+ m = Message()
+ m.add_byte(cMSG_USERAUTH_PK_OK)
+ m.add_string(algorithm)
+ m.add_string(keyblob)
+ self.transport._send_message(m)
+ return
+ sig = Message(m.get_binary())
+ blob = self._get_session_blob(
+ key, service, username, algorithm
+ )
+ if not key.verify_ssh_sig(blob, sig):
+ self._log(INFO, "Auth rejected: invalid signature")
+ result = AUTH_FAILED
+ elif method == "keyboard-interactive":
+ submethods = m.get_string()
+ result = self.transport.server_object.check_auth_interactive(
+ username, submethods
+ )
+ if isinstance(result, InteractiveQuery):
+ # make interactive query instead of response
+ self._interactive_query(result)
+ return
+ elif method == "gssapi-with-mic" and gss_auth:
+ sshgss = GSSAuth(method)
+ # Read the number of OID mechanisms supported by the client.
+ # OpenSSH sends just one OID. It's the Kerberos V5 OID and that's
+ # the only OID we support.
+ mechs = m.get_int()
+ # We can't accept more than one OID, so if the SSH client sends
+ # more than one, disconnect.
+ if mechs > 1:
+ self._log(
+ INFO,
+ "Disconnect: Received more than one GSS-API OID mechanism",
+ )
+ self._disconnect_no_more_auth()
+ desired_mech = m.get_string()
+ mech_ok = sshgss.ssh_check_mech(desired_mech)
+ # if we don't support the mechanism, disconnect.
+ if not mech_ok:
+ self._log(
+ INFO,
+ "Disconnect: Received an invalid GSS-API OID mechanism",
+ )
+ self._disconnect_no_more_auth()
+ # send the Kerberos V5 GSSAPI OID to the client
+ supported_mech = sshgss.ssh_gss_oids("server")
+ # RFC 4462 says we are not required to implement GSS-API error
+ # messages. See section 3.8 in http://www.ietf.org/rfc/rfc4462.txt
+ m = Message()
+ m.add_byte(cMSG_USERAUTH_GSSAPI_RESPONSE)
+ m.add_bytes(supported_mech)
+ self.transport.auth_handler = GssapiWithMicAuthHandler(
+ self, sshgss
+ )
+ self.transport._expected_packet = (
+ MSG_USERAUTH_GSSAPI_TOKEN,
+ MSG_USERAUTH_REQUEST,
+ MSG_SERVICE_REQUEST,
+ )
+ self.transport._send_message(m)
+ return
+ elif method == "gssapi-keyex" and gss_auth:
+ mic_token = m.get_string()
+ sshgss = self.transport.kexgss_ctxt
+ if sshgss is None:
+ # If there is no valid context, we reject the authentication
+ result = AUTH_FAILED
+ self._send_auth_result(username, method, result)
+ try:
+ sshgss.ssh_check_mic(
+ mic_token, self.transport.session_id, self.auth_username
+ )
+ except Exception:
+ result = AUTH_FAILED
+ self._send_auth_result(username, method, result)
+ raise
+ result = AUTH_SUCCESSFUL
+ self.transport.server_object.check_auth_gssapi_keyex(
+ username, result
+ )
+ else:
+ result = self.transport.server_object.check_auth_none(username)
+ # okay, send result
+ self._send_auth_result(username, method, result)
+
+ def _parse_userauth_success(self, m):
+ self._log(
+ INFO, "Authentication ({}) successful!".format(self.auth_method)
+ )
+ self.authenticated = True
+ self.transport._auth_trigger()
+ if self.auth_event is not None:
+ self.auth_event.set()
+
+ def _parse_userauth_failure(self, m):
+ authlist = m.get_list()
+ # TODO 4.0: we aren't giving callers access to authlist _unless_ it's
+ # partial authentication, so eg authtype=none can't work unless we
+ # tweak this.
+ partial = m.get_boolean()
+ if partial:
+ self._log(INFO, "Authentication continues...")
+ self._log(DEBUG, "Methods: " + str(authlist))
+ self.transport.saved_exception = PartialAuthentication(authlist)
+ elif self.auth_method not in authlist:
+ for msg in (
+ "Authentication type ({}) not permitted.".format(
+ self.auth_method
+ ),
+ "Allowed methods: {}".format(authlist),
+ ):
+ self._log(DEBUG, msg)
+ self.transport.saved_exception = BadAuthenticationType(
+ "Bad authentication type", authlist
+ )
+ else:
+ self._log(
+ INFO, "Authentication ({}) failed.".format(self.auth_method)
+ )
+ self.authenticated = False
+ self.username = None
+ if self.auth_event is not None:
+ self.auth_event.set()
+
+ def _parse_userauth_banner(self, m):
+ banner = m.get_string()
+ self.banner = banner
+ self._log(INFO, "Auth banner: {}".format(banner))
+ # who cares.
+
+ def _parse_userauth_info_request(self, m):
+ if self.auth_method != "keyboard-interactive":
+ raise SSHException("Illegal info request from server")
+ title = m.get_text()
+ instructions = m.get_text()
+ m.get_binary() # lang
+ prompts = m.get_int()
+ prompt_list = []
+ for i in range(prompts):
+ prompt_list.append((m.get_text(), m.get_boolean()))
+ response_list = self.interactive_handler(
+ title, instructions, prompt_list
+ )
+
+ m = Message()
+ m.add_byte(cMSG_USERAUTH_INFO_RESPONSE)
+ m.add_int(len(response_list))
+ for r in response_list:
+ m.add_string(r)
+ self.transport._send_message(m)
+
+ def _parse_userauth_info_response(self, m):
+ if not self.transport.server_mode:
+ raise SSHException("Illegal info response from server")
+ n = m.get_int()
+ responses = []
+ for i in range(n):
+ responses.append(m.get_text())
+ result = self.transport.server_object.check_auth_interactive_response(
+ responses
+ )
+ if isinstance(result, InteractiveQuery):
+ # make interactive query instead of response
+ self._interactive_query(result)
+ return
+ self._send_auth_result(
+ self.auth_username, "keyboard-interactive", result
+ )
+
+ def _handle_local_gss_failure(self, e):
+ self.transport.saved_exception = e
+ self._log(DEBUG, "GSSAPI failure: {}".format(e))
+ self._log(INFO, "Authentication ({}) failed.".format(self.auth_method))
+ self.authenticated = False
+ self.username = None
+ if self.auth_event is not None:
+ self.auth_event.set()
+ return
+
+ # TODO 4.0: MAY make sense to make these tables into actual
+ # classes/instances that can be fed a mode bool or whatever. Or,
+ # alternately (both?) make the message types small classes or enums that
+ # embed this info within themselves (which could also then tidy up the
+ # current 'integer -> human readable short string' stuff in common.py).
+ # TODO: if we do that, also expose 'em publicly.
+
+ # Messages which should be handled _by_ servers (sent by clients)
+ @property
+ def _server_handler_table(self):
+ return {
+ # TODO 4.0: MSG_SERVICE_REQUEST ought to eventually move into
+ # Transport's server mode like the client side did, just for
+ # consistency.
+ MSG_SERVICE_REQUEST: self._parse_service_request,
+ MSG_USERAUTH_REQUEST: self._parse_userauth_request,
+ MSG_USERAUTH_INFO_RESPONSE: self._parse_userauth_info_response,
+ }
+
+ # Messages which should be handled _by_ clients (sent by servers)
+ @property
+ def _client_handler_table(self):
+ return {
+ MSG_SERVICE_ACCEPT: self._parse_service_accept,
+ MSG_USERAUTH_SUCCESS: self._parse_userauth_success,
+ MSG_USERAUTH_FAILURE: self._parse_userauth_failure,
+ MSG_USERAUTH_BANNER: self._parse_userauth_banner,
+ MSG_USERAUTH_INFO_REQUEST: self._parse_userauth_info_request,
+ }
+
+ # NOTE: prior to the fix for #1283, this was a static dict instead of a
+ # property. Should be backwards compatible in most/all cases.
+ @property
+ def _handler_table(self):
+ if self.transport.server_mode:
+ return self._server_handler_table
+ else:
+ return self._client_handler_table
class GssapiWithMicAuthHandler:
@@ -55,15 +863,106 @@ class GssapiWithMicAuthHandler:
During the GSSAPI token exchange we need a modified dispatch table,
because the packet type numbers are not unique.
"""
- method = 'gssapi-with-mic'
+
+ method = "gssapi-with-mic"
def __init__(self, delegate, sshgss):
self._delegate = delegate
self.sshgss = sshgss
- __handler_table = {MSG_SERVICE_REQUEST: _parse_service_request,
+
+ def abort(self):
+ self._restore_delegate_auth_handler()
+ return self._delegate.abort()
+
+ @property
+ def transport(self):
+ return self._delegate.transport
+
+ @property
+ def _send_auth_result(self):
+ return self._delegate._send_auth_result
+
+ @property
+ def auth_username(self):
+ return self._delegate.auth_username
+
+ @property
+ def gss_host(self):
+ return self._delegate.gss_host
+
+ def _restore_delegate_auth_handler(self):
+ self.transport.auth_handler = self._delegate
+
+ def _parse_userauth_gssapi_token(self, m):
+ client_token = m.get_string()
+ # use the client token as input to establish a secure
+ # context.
+ sshgss = self.sshgss
+ try:
+ token = sshgss.ssh_accept_sec_context(
+ self.gss_host, client_token, self.auth_username
+ )
+ except Exception as e:
+ self.transport.saved_exception = e
+ result = AUTH_FAILED
+ self._restore_delegate_auth_handler()
+ self._send_auth_result(self.auth_username, self.method, result)
+ raise
+ if token is not None:
+ m = Message()
+ m.add_byte(cMSG_USERAUTH_GSSAPI_TOKEN)
+ m.add_string(token)
+ self.transport._expected_packet = (
+ MSG_USERAUTH_GSSAPI_TOKEN,
+ MSG_USERAUTH_GSSAPI_MIC,
+ MSG_USERAUTH_REQUEST,
+ )
+ self.transport._send_message(m)
+
+ def _parse_userauth_gssapi_mic(self, m):
+ mic_token = m.get_string()
+ sshgss = self.sshgss
+ username = self.auth_username
+ self._restore_delegate_auth_handler()
+ try:
+ sshgss.ssh_check_mic(
+ mic_token, self.transport.session_id, username
+ )
+ except Exception as e:
+ self.transport.saved_exception = e
+ result = AUTH_FAILED
+ self._send_auth_result(username, self.method, result)
+ raise
+ # TODO: Implement client credential saving.
+ # The OpenSSH server is able to create a TGT with the delegated
+ # client credentials, but this is not supported by GSS-API.
+ result = AUTH_SUCCESSFUL
+ self.transport.server_object.check_auth_gssapi_with_mic(
+ username, result
+ )
+ # okay, send result
+ self._send_auth_result(username, self.method, result)
+
+ def _parse_service_request(self, m):
+ self._restore_delegate_auth_handler()
+ return self._delegate._parse_service_request(m)
+
+ def _parse_userauth_request(self, m):
+ self._restore_delegate_auth_handler()
+ return self._delegate._parse_userauth_request(m)
+
+ __handler_table = {
+ MSG_SERVICE_REQUEST: _parse_service_request,
MSG_USERAUTH_REQUEST: _parse_userauth_request,
MSG_USERAUTH_GSSAPI_TOKEN: _parse_userauth_gssapi_token,
- MSG_USERAUTH_GSSAPI_MIC: _parse_userauth_gssapi_mic}
+ MSG_USERAUTH_GSSAPI_MIC: _parse_userauth_gssapi_mic,
+ }
+
+ @property
+ def _handler_table(self):
+ # TODO: determine if we can cut this up like we did for the primary
+ # AuthHandler class.
+ return self.__handler_table
class AuthOnlyHandler(AuthHandler):
@@ -73,6 +972,16 @@ class AuthOnlyHandler(AuthHandler):
.. versionadded:: 3.2
"""
+ # NOTE: this purposefully duplicates some of the parent class in order to
+ # modernize, refactor, etc. The intent is that eventually we will collapse
+ # this one onto the parent in a backwards incompatible release.
+
+ @property
+ def _client_handler_table(self):
+ my_table = super()._client_handler_table.copy()
+ del my_table[MSG_SERVICE_ACCEPT]
+ return my_table
+
def send_auth_request(self, username, method, finish_message=None):
"""
Submit a userauth request message & wait for response.
@@ -85,10 +994,99 @@ class AuthOnlyHandler(AuthHandler):
which accepts a Message ``m`` and may call mutator methods on it to add
more fields.
"""
- pass
+ # Store a few things for reference in handlers, including auth failure
+ # handler (which needs to know if we were using a bad method, etc)
+ self.auth_method = method
+ self.username = username
+ # Generic userauth request fields
+ m = Message()
+ m.add_byte(cMSG_USERAUTH_REQUEST)
+ m.add_string(username)
+ m.add_string("ssh-connection")
+ m.add_string(method)
+ # Caller usually has more to say, such as injecting password, key etc
+ finish_message(m)
+ # TODO 4.0: seems odd to have the client handle the lock and not
+ # Transport; that _may_ have been an artifact of allowing user
+ # threading event injection? Regardless, we don't want to move _this_
+ # locking into Transport._send_message now, because lots of other
+ # untouched code also uses that method and we might end up
+ # double-locking (?) but 4.0 would be a good time to revisit.
+ with self.transport.lock:
+ self.transport._send_message(m)
+ # We have cut out the higher level event args, but self.auth_event is
+ # still required for self.wait_for_response to function correctly (it's
+ # the mechanism used by the auth success/failure handlers, the abort
+ # handler, and a few other spots like in gssapi).
+ # TODO: interestingly, wait_for_response itself doesn't actually
+ # enforce that its event argument and self.auth_event are the same...
+ self.auth_event = threading.Event()
+ return self.wait_for_response(self.auth_event)
+
+ def auth_none(self, username):
+ return self.send_auth_request(username, "none")
- def auth_interactive(self, username, handler, submethods=''):
+ def auth_publickey(self, username, key):
+ key_type, bits = self._get_key_type_and_bits(key)
+ algorithm = self._finalize_pubkey_algorithm(key_type)
+ blob = self._get_session_blob(
+ key,
+ "ssh-connection",
+ username,
+ algorithm,
+ )
+
+ def finish(m):
+ # This field doesn't appear to be named, but is False when querying
+ # for permission (ie knowing whether to even prompt a user for
+ # passphrase, etc) or True when just going for it. Paramiko has
+ # never bothered with the former type of message, apparently.
+ m.add_boolean(True)
+ m.add_string(algorithm)
+ m.add_string(bits)
+ m.add_string(key.sign_ssh_data(blob, algorithm))
+
+ return self.send_auth_request(username, "publickey", finish)
+
+ def auth_password(self, username, password):
+ def finish(m):
+ # Unnamed field that equates to "I am changing my password", which
+ # Paramiko clientside never supported and serverside only sort of
+ # supported.
+ m.add_boolean(False)
+ m.add_string(b(password))
+
+ return self.send_auth_request(username, "password", finish)
+
+ def auth_interactive(self, username, handler, submethods=""):
"""
response_list = handler(title, instructions, prompt_list)
"""
- pass
+ # Unlike most siblings, this auth method _does_ require other
+ # superclass handlers (eg userauth info request) to understand
+ # what's going on, so we still set some self attributes.
+ self.auth_method = "keyboard_interactive"
+ self.interactive_handler = handler
+
+ def finish(m):
+ # Empty string for deprecated language tag field, per RFC 4256:
+ # https://www.rfc-editor.org/rfc/rfc4256#section-3.1
+ m.add_string("")
+ m.add_string(submethods)
+
+ return self.send_auth_request(username, "keyboard-interactive", finish)
+
+ # NOTE: not strictly 'auth only' related, but allows users to opt-in.
+ def _choose_fallback_pubkey_algorithm(self, key_type, my_algos):
+ msg = "Server did not send a server-sig-algs list; defaulting to something in our preferred algorithms list" # noqa
+ self._log(DEBUG, msg)
+ noncert_key_type = key_type.replace("-cert-v01@openssh.com", "")
+ if key_type in my_algos or noncert_key_type in my_algos:
+ actual = key_type if key_type in my_algos else noncert_key_type
+ msg = f"Current key type, {actual!r}, is in our preferred list; using that" # noqa
+ algo = actual
+ else:
+ algo = my_algos[0]
+ msg = f"{key_type!r} not in our list - trying first list item instead, {algo!r}" # noqa
+ self._log(DEBUG, msg)
+ return algo
diff --git a/paramiko/auth_strategy.py b/paramiko/auth_strategy.py
index 318c2713..03c1d877 100644
--- a/paramiko/auth_strategy.py
+++ b/paramiko/auth_strategy.py
@@ -4,7 +4,9 @@ Modern, adaptable authentication machinery.
Replaces certain parts of `.SSHClient`. For a concrete implementation, see the
``OpenSSHAuthStrategy`` class in `Fabric <https://fabfile.org>`_.
"""
+
from collections import namedtuple
+
from .agent import AgentKey
from .util import get_logger
from .ssh_exception import AuthenticationException
@@ -22,6 +24,13 @@ class AuthSource:
def __init__(self, username):
self.username = username
+ def _repr(self, **kwargs):
+ # TODO: are there any good libs for this? maybe some helper from
+ # structlog?
+ pairs = [f"{k}={v!r}" for k, v in kwargs.items()]
+ joined = ", ".join(pairs)
+ return f"{self.__class__.__name__}({joined})"
+
def __repr__(self):
return self._repr()
@@ -29,7 +38,7 @@ class AuthSource:
"""
Perform authentication.
"""
- pass
+ raise NotImplementedError
class NoneAuth(AuthSource):
@@ -37,6 +46,9 @@ class NoneAuth(AuthSource):
Auth type "none", ie https://www.rfc-editor.org/rfc/rfc4252#section-5.2 .
"""
+ def authenticate(self, transport):
+ return transport.auth_none(self.username)
+
class Password(AuthSource):
"""
@@ -58,9 +70,21 @@ class Password(AuthSource):
self.password_getter = password_getter
def __repr__(self):
+ # Password auth is marginally more 'username-caring' than pkeys, so may
+ # as well log that info here.
return super()._repr(user=self.username)
+ def authenticate(self, transport):
+ # Lazily get the password, in case it's prompting a user
+ # TODO: be nice to log source _of_ the password?
+ password = self.password_getter()
+ return transport.auth_password(self.username, password)
+
+# TODO 4.0: twiddle this, or PKey, or both, so they're more obviously distinct.
+# TODO 4.0: the obvious is to make this more wordy (PrivateKeyAuth), the
+# minimalist approach might be to rename PKey to just Key (esp given all the
+# subclasses are WhateverKey and not WhateverPKey)
class PrivateKey(AuthSource):
"""
Essentially a mixin for private keys.
@@ -74,6 +98,9 @@ class PrivateKey(AuthSource):
its `super` call.
"""
+ def authenticate(self, transport):
+ return transport.auth_publickey(self.username, self.pkey)
+
class InMemoryPrivateKey(PrivateKey):
"""
@@ -82,12 +109,15 @@ class InMemoryPrivateKey(PrivateKey):
def __init__(self, username, pkey):
super().__init__(username=username)
+ # No decryption (presumably) necessary!
self.pkey = pkey
def __repr__(self):
+ # NOTE: most of the interesting repr bits for private keys are in PKey.
+ # TODO: tacking on agent-ness like this is a bit awkward, but, eh?
rep = super()._repr(pkey=self.pkey)
if isinstance(self.pkey, AgentKey):
- rep += ' [agent]'
+ rep += " [agent]"
return rep
@@ -107,20 +137,34 @@ class OnDiskPrivateKey(PrivateKey):
def __init__(self, username, source, path, pkey):
super().__init__(username=username)
self.source = source
- allowed = 'ssh-config', 'python-config', 'implicit-home'
+ allowed = ("ssh-config", "python-config", "implicit-home")
if source not in allowed:
- raise ValueError(f'source argument must be one of: {allowed!r}')
+ raise ValueError(f"source argument must be one of: {allowed!r}")
self.path = path
+ # Superclass wants .pkey, other two are mostly for display/debugging.
self.pkey = pkey
def __repr__(self):
- return self._repr(key=self.pkey, source=self.source, path=str(self.
- path))
+ return self._repr(
+ key=self.pkey, source=self.source, path=str(self.path)
+ )
+
+# TODO re sources: is there anything in an OpenSSH config file that doesn't fit
+# into what Paramiko already had kwargs for?
-SourceResult = namedtuple('SourceResult', ['source', 'result'])
+SourceResult = namedtuple("SourceResult", ["source", "result"])
+# TODO: tempting to make this an OrderedDict, except the keys essentially want
+# to be rich objects (AuthSources) which do not make for useful user indexing?
+# TODO: members being vanilla tuples is pretty old-school/expedient; they
+# "really" want to be something that's type friendlier (unless the tuple's 2nd
+# member being a Union of two types is "fine"?), which I assume means yet more
+# classes, eg an abstract SourceResult with concrete AuthSuccess and
+# AuthFailure children?
+# TODO: arguably we want __init__ typechecking of the members (or to leverage
+# mypy by classifying this literally as list-of-AuthSource?)
class AuthResult(list):
"""
Represents a partial or complete SSH authentication attempt.
@@ -157,10 +201,16 @@ class AuthResult(list):
super().__init__(*args, **kwargs)
def __str__(self):
- return '\n'.join(f"{x.source} -> {x.result or 'success'}" for x in self
- )
+ # NOTE: meaningfully distinct from __repr__, which still wants to use
+ # superclass' implementation.
+ # TODO: go hog wild, use rich.Table? how is that on degraded term's?
+ # TODO: test this lol
+ return "\n".join(
+ f"{x.source} -> {x.result or 'success'}" for x in self
+ )
+# TODO 4.0: descend from SSHException or even just Exception
class AuthFailure(AuthenticationException):
"""
Basic exception wrapping an `AuthResult` indicating overall auth failure.
@@ -176,7 +226,7 @@ class AuthFailure(AuthenticationException):
self.result = result
def __str__(self):
- return '\n' + str(self.result)
+ return "\n" + str(self.result)
class AuthStrategy:
@@ -188,7 +238,10 @@ class AuthStrategy:
their particular strategy.
"""
- def __init__(self, ssh_config):
+ def __init__(
+ self,
+ ssh_config,
+ ):
self.ssh_config = ssh_config
self.log = get_logger(__name__)
@@ -202,7 +255,7 @@ class AuthStrategy:
Subclasses _of_ subclasses may find themselves wanting to do things
like filtering or discarding around a call to `super`.
"""
- pass
+ raise NotImplementedError
def authenticate(self, transport):
"""
@@ -211,4 +264,43 @@ class AuthStrategy:
You *normally* won't need to override this, but it's an option for
advanced users.
"""
- pass
+ succeeded = False
+ overall_result = AuthResult(strategy=self)
+ # TODO: arguably we could fit in a "send none auth, record allowed auth
+ # types sent back" thing here as OpenSSH-client does, but that likely
+ # wants to live in fabric.OpenSSHAuthStrategy as not all target servers
+ # will implement it!
+ # TODO: needs better "server told us too many attempts" checking!
+ for source in self.get_sources():
+ self.log.debug(f"Trying {source}")
+ try: # NOTE: this really wants to _only_ wrap the authenticate()!
+ result = source.authenticate(transport)
+ succeeded = True
+ # TODO: 'except PartialAuthentication' is needed for 2FA and
+ # similar, as per old SSHClient.connect - it is the only way
+ # AuthHandler supplies access to the 'name-list' field from
+ # MSG_USERAUTH_FAILURE, at present.
+ except Exception as e:
+ result = e
+ # TODO: look at what this could possibly raise, we don't really
+ # want Exception here, right? just SSHException subclasses? or
+ # do we truly want to capture anything at all with assumption
+ # it's easy enough for users to look afterwards?
+ # NOTE: showing type, not message, for tersity & also most of
+ # the time it's basically just "Authentication failed."
+ source_class = e.__class__.__name__
+ self.log.info(
+ f"Authentication via {source} failed with {source_class}"
+ )
+ overall_result.append(SourceResult(source, result))
+ if succeeded:
+ break
+ # Gotta die here if nothing worked, otherwise Transport's main loop
+ # just kinda hangs out until something times out!
+ if not succeeded:
+ raise AuthFailure(result=overall_result)
+ # Success: give back what was done, in case they care.
+ return overall_result
+
+ # TODO: is there anything OpenSSH client does which _can't_ cleanly map to
+ # iterating a generator?
diff --git a/paramiko/ber.py b/paramiko/ber.py
index 18f67749..b8287f5d 100644
--- a/paramiko/ber.py
+++ b/paramiko/ber.py
@@ -1,4 +1,22 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
from paramiko.common import max_byte, zero_byte, byte_ord, byte_chr
+
import paramiko.util as util
from paramiko.util import b
from paramiko.sftp import int64
@@ -17,8 +35,105 @@ class BER:
self.content = b(content)
self.idx = 0
+ def asbytes(self):
+ return self.content
+
def __str__(self):
return self.asbytes()
def __repr__(self):
return "BER('" + repr(self.content) + "')"
+
+ def decode(self):
+ return self.decode_next()
+
+ def decode_next(self):
+ if self.idx >= len(self.content):
+ return None
+ ident = byte_ord(self.content[self.idx])
+ self.idx += 1
+ if (ident & 31) == 31:
+ # identifier > 30
+ ident = 0
+ while self.idx < len(self.content):
+ t = byte_ord(self.content[self.idx])
+ self.idx += 1
+ ident = (ident << 7) | (t & 0x7F)
+ if not (t & 0x80):
+ break
+ if self.idx >= len(self.content):
+ return None
+ # now fetch length
+ size = byte_ord(self.content[self.idx])
+ self.idx += 1
+ if size & 0x80:
+ # more complimicated...
+ # FIXME: theoretically should handle indefinite-length (0x80)
+ t = size & 0x7F
+ if self.idx + t > len(self.content):
+ return None
+ size = util.inflate_long(
+ self.content[self.idx : self.idx + t], True
+ )
+ self.idx += t
+ if self.idx + size > len(self.content):
+ # can't fit
+ return None
+ data = self.content[self.idx : self.idx + size]
+ self.idx += size
+ # now switch on id
+ if ident == 0x30:
+ # sequence
+ return self.decode_sequence(data)
+ elif ident == 2:
+ # int
+ return util.inflate_long(data)
+ else:
+ # 1: boolean (00 false, otherwise true)
+ msg = "Unknown ber encoding type {:d} (robey is lazy)"
+ raise BERException(msg.format(ident))
+
+ @staticmethod
+ def decode_sequence(data):
+ out = []
+ ber = BER(data)
+ while True:
+ x = ber.decode_next()
+ if x is None:
+ break
+ out.append(x)
+ return out
+
+ def encode_tlv(self, ident, val):
+ # no need to support ident > 31 here
+ self.content += byte_chr(ident)
+ if len(val) > 0x7F:
+ lenstr = util.deflate_long(len(val))
+ self.content += byte_chr(0x80 + len(lenstr)) + lenstr
+ else:
+ self.content += byte_chr(len(val))
+ self.content += val
+
+ def encode(self, x):
+ if type(x) is bool:
+ if x:
+ self.encode_tlv(1, max_byte)
+ else:
+ self.encode_tlv(1, zero_byte)
+ elif (type(x) is int) or (type(x) is int64):
+ self.encode_tlv(2, util.deflate_long(x))
+ elif type(x) is str:
+ self.encode_tlv(4, x)
+ elif (type(x) is list) or (type(x) is tuple):
+ self.encode_tlv(0x30, self.encode_sequence(x))
+ else:
+ raise BERException(
+ "Unknown type for encoding: {!r}".format(type(x))
+ )
+
+ @staticmethod
+ def encode_sequence(data):
+ ber = BER()
+ for item in data:
+ ber.encode(item)
+ return ber.asbytes()
diff --git a/paramiko/buffered_pipe.py b/paramiko/buffered_pipe.py
index 0e56ca4d..c19279c0 100644
--- a/paramiko/buffered_pipe.py
+++ b/paramiko/buffered_pipe.py
@@ -1,8 +1,27 @@
+# Copyright (C) 2006-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Attempt to generalize the "feeder" part of a `.Channel`: an object which can be
read from and closed, but is reading from a buffer fed by another thread. The
read operations are blocking and can have a timeout set.
"""
+
import array
import threading
import time
@@ -13,6 +32,7 @@ class PipeTimeout(IOError):
"""
Indicates that a timeout was reached on a read from a `.BufferedPipe`.
"""
+
pass
@@ -27,9 +47,15 @@ class BufferedPipe:
self._lock = threading.Lock()
self._cv = threading.Condition(self._lock)
self._event = None
- self._buffer = array.array('B')
+ self._buffer = array.array("B")
self._closed = False
+ def _buffer_frombytes(self, data):
+ self._buffer.frombytes(data)
+
+ def _buffer_tobytes(self, limit=None):
+ return self._buffer[:limit].tobytes()
+
def set_event(self, event):
"""
Set an event on this buffer. When data is ready to be read (or the
@@ -38,7 +64,20 @@ class BufferedPipe:
:param threading.Event event: the event to set/clear
"""
- pass
+ self._lock.acquire()
+ try:
+ self._event = event
+ # Make sure the event starts in `set` state if we appear to already
+ # be closed; otherwise, if we start in `clear` state & are closed,
+ # nothing will ever call `.feed` and the event (& OS pipe, if we're
+ # wrapping one - see `Channel.fileno`) will permanently stay in
+ # `clear`, causing deadlock if e.g. `select`ed upon.
+ if self._closed or len(self._buffer) > 0:
+ event.set()
+ else:
+ event.clear()
+ finally:
+ self._lock.release()
def feed(self, data):
"""
@@ -47,7 +86,14 @@ class BufferedPipe:
:param data: the data to add, as a ``str`` or ``bytes``
"""
- pass
+ self._lock.acquire()
+ try:
+ if self._event is not None:
+ self._event.set()
+ self._buffer_frombytes(b(data))
+ self._cv.notify_all()
+ finally:
+ self._lock.release()
def read_ready(self):
"""
@@ -59,7 +105,13 @@ class BufferedPipe:
``True`` if a `read` call would immediately return at least one
byte; ``False`` otherwise.
"""
- pass
+ self._lock.acquire()
+ try:
+ if len(self._buffer) == 0:
+ return False
+ return True
+ finally:
+ self._lock.release()
def read(self, nbytes, timeout=None):
"""
@@ -82,7 +134,38 @@ class BufferedPipe:
`.PipeTimeout` -- if a timeout was specified and no data was ready
before that timeout
"""
- pass
+ out = bytes()
+ self._lock.acquire()
+ try:
+ if len(self._buffer) == 0:
+ if self._closed:
+ return out
+ # should we block?
+ if timeout == 0.0:
+ raise PipeTimeout()
+ # loop here in case we get woken up but a different thread has
+ # grabbed everything in the buffer.
+ while (len(self._buffer) == 0) and not self._closed:
+ then = time.time()
+ self._cv.wait(timeout)
+ if timeout is not None:
+ timeout -= time.time() - then
+ if timeout <= 0.0:
+ raise PipeTimeout()
+
+ # something's in the buffer and we have the lock!
+ if len(self._buffer) <= nbytes:
+ out = self._buffer_tobytes()
+ del self._buffer[:]
+ if (self._event is not None) and not self._closed:
+ self._event.clear()
+ else:
+ out = self._buffer_tobytes(nbytes)
+ del self._buffer[:nbytes]
+ finally:
+ self._lock.release()
+
+ return out
def empty(self):
"""
@@ -92,14 +175,29 @@ class BufferedPipe:
any data that was in the buffer prior to clearing it out, as a
`str`
"""
- pass
+ self._lock.acquire()
+ try:
+ out = self._buffer_tobytes()
+ del self._buffer[:]
+ if (self._event is not None) and not self._closed:
+ self._event.clear()
+ return out
+ finally:
+ self._lock.release()
def close(self):
"""
Close this pipe object. Future calls to `read` after the buffer
has been emptied will return immediately with an empty string.
"""
- pass
+ self._lock.acquire()
+ try:
+ self._closed = True
+ self._cv.notify_all()
+ if self._event is not None:
+ self._event.set()
+ finally:
+ self._lock.release()
def __len__(self):
"""
diff --git a/paramiko/channel.py b/paramiko/channel.py
index 45548521..2757450b 100644
--- a/paramiko/channel.py
+++ b/paramiko/channel.py
@@ -1,14 +1,46 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Abstraction for an SSH2 channel.
"""
+
import binascii
import os
import socket
import time
import threading
+
from functools import wraps
+
from paramiko import util
-from paramiko.common import cMSG_CHANNEL_REQUEST, cMSG_CHANNEL_WINDOW_ADJUST, cMSG_CHANNEL_DATA, cMSG_CHANNEL_EXTENDED_DATA, DEBUG, ERROR, cMSG_CHANNEL_SUCCESS, cMSG_CHANNEL_FAILURE, cMSG_CHANNEL_EOF, cMSG_CHANNEL_CLOSE
+from paramiko.common import (
+ cMSG_CHANNEL_REQUEST,
+ cMSG_CHANNEL_WINDOW_ADJUST,
+ cMSG_CHANNEL_DATA,
+ cMSG_CHANNEL_EXTENDED_DATA,
+ DEBUG,
+ ERROR,
+ cMSG_CHANNEL_SUCCESS,
+ cMSG_CHANNEL_FAILURE,
+ cMSG_CHANNEL_EOF,
+ cMSG_CHANNEL_CLOSE,
+)
from paramiko.message import Message
from paramiko.ssh_exception import SSHException
from paramiko.file import BufferedFile
@@ -25,7 +57,19 @@ def open_only(func):
`.SSHException` -- If the wrapped method is called on an unopened
`.Channel`.
"""
- pass
+
+ @wraps(func)
+ def _check(self, *args, **kwds):
+ if (
+ self.closed
+ or self.eof_received
+ or self.eof_sent
+ or not self.active
+ ):
+ raise SSHException("Channel is not open")
+ return func(self, *args, **kwds)
+
+ return _check
class Channel(ClosingContextManager):
@@ -55,15 +99,20 @@ class Channel(ClosingContextManager):
:param int chanid:
the ID of this channel, as passed by an existing `.Transport`.
"""
+ #: Channel ID
self.chanid = chanid
+ #: Remote channel ID
self.remote_chanid = 0
+ #: `.Transport` managing this channel
self.transport = None
+ #: Whether the connection is presently active
self.active = False
self.eof_received = 0
self.eof_sent = 0
self.in_buffer = BufferedPipe()
self.in_stderr_buffer = BufferedPipe()
self.timeout = None
+ #: Whether the connection has been closed
self.closed = False
self.ultra_debug = False
self.lock = threading.Lock()
@@ -76,7 +125,7 @@ class Channel(ClosingContextManager):
self.in_window_sofar = 0
self.status_event = threading.Event()
self._name = str(chanid)
- self.logger = util.get_logger('paramiko.transport')
+ self.logger = util.get_logger("paramiko.transport")
self._pipe = None
self.event = threading.Event()
self.event_ready = False
@@ -94,24 +143,30 @@ class Channel(ClosingContextManager):
"""
Return a string representation of this object, for debugging.
"""
- out = '<paramiko.Channel {}'.format(self.chanid)
+ out = "<paramiko.Channel {}".format(self.chanid)
if self.closed:
- out += ' (closed)'
+ out += " (closed)"
elif self.active:
if self.eof_received:
- out += ' (EOF received)'
+ out += " (EOF received)"
if self.eof_sent:
- out += ' (EOF sent)'
- out += ' (open) window={}'.format(self.out_window_size)
+ out += " (EOF sent)"
+ out += " (open) window={}".format(self.out_window_size)
if len(self.in_buffer) > 0:
- out += ' in-buffer={}'.format(len(self.in_buffer))
- out += ' -> ' + repr(self.transport)
- out += '>'
+ out += " in-buffer={}".format(len(self.in_buffer))
+ out += " -> " + repr(self.transport)
+ out += ">"
return out
@open_only
- def get_pty(self, term='vt100', width=80, height=24, width_pixels=0,
- height_pixels=0):
+ def get_pty(
+ self,
+ term="vt100",
+ width=80,
+ height=24,
+ width_pixels=0,
+ height_pixels=0,
+ ):
"""
Request a pseudo-terminal from the server. This is usually used right
after creating a client channel, to ask the server to provide some
@@ -130,7 +185,20 @@ class Channel(ClosingContextManager):
`.SSHException` -- if the request was rejected or the channel was
closed
"""
- pass
+ m = Message()
+ m.add_byte(cMSG_CHANNEL_REQUEST)
+ m.add_int(self.remote_chanid)
+ m.add_string("pty-req")
+ m.add_boolean(True)
+ m.add_string(term)
+ m.add_int(width)
+ m.add_int(height)
+ m.add_int(width_pixels)
+ m.add_int(height_pixels)
+ m.add_string(bytes())
+ self._event_pending()
+ self.transport._send_user_message(m)
+ self._wait_for_event()
@open_only
def invoke_shell(self):
@@ -150,7 +218,14 @@ class Channel(ClosingContextManager):
`.SSHException` -- if the request was rejected or the channel was
closed
"""
- pass
+ m = Message()
+ m.add_byte(cMSG_CHANNEL_REQUEST)
+ m.add_int(self.remote_chanid)
+ m.add_string("shell")
+ m.add_boolean(True)
+ self._event_pending()
+ self.transport._send_user_message(m)
+ self._wait_for_event()
@open_only
def exec_command(self, command):
@@ -169,7 +244,15 @@ class Channel(ClosingContextManager):
`.SSHException` -- if the request was rejected or the channel was
closed
"""
- pass
+ m = Message()
+ m.add_byte(cMSG_CHANNEL_REQUEST)
+ m.add_int(self.remote_chanid)
+ m.add_string("exec")
+ m.add_boolean(True)
+ m.add_string(command)
+ self._event_pending()
+ self.transport._send_user_message(m)
+ self._wait_for_event()
@open_only
def invoke_subsystem(self, subsystem):
@@ -187,7 +270,15 @@ class Channel(ClosingContextManager):
`.SSHException` -- if the request was rejected or the channel was
closed
"""
- pass
+ m = Message()
+ m.add_byte(cMSG_CHANNEL_REQUEST)
+ m.add_int(self.remote_chanid)
+ m.add_string("subsystem")
+ m.add_boolean(True)
+ m.add_string(subsystem)
+ self._event_pending()
+ self.transport._send_user_message(m)
+ self._wait_for_event()
@open_only
def resize_pty(self, width=80, height=24, width_pixels=0, height_pixels=0):
@@ -204,7 +295,16 @@ class Channel(ClosingContextManager):
`.SSHException` -- if the request was rejected or the channel was
closed
"""
- pass
+ m = Message()
+ m.add_byte(cMSG_CHANNEL_REQUEST)
+ m.add_int(self.remote_chanid)
+ m.add_string("window-change")
+ m.add_boolean(False)
+ m.add_int(width)
+ m.add_int(height)
+ m.add_int(width_pixels)
+ m.add_int(height_pixels)
+ self.transport._send_user_message(m)
@open_only
def update_environment(self, environment):
@@ -225,7 +325,12 @@ class Channel(ClosingContextManager):
`.SSHException` -- if any of the environment variables was rejected
by the server or the channel was closed
"""
- pass
+ for name, value in environment.items():
+ try:
+ self.set_environment_variable(name, value)
+ except SSHException as e:
+ err = 'Failed to set environment variable "{}".'
+ raise SSHException(err.format(name), e)
@open_only
def set_environment_variable(self, name, value):
@@ -245,7 +350,14 @@ class Channel(ClosingContextManager):
`.SSHException` -- if the request was rejected or the channel was
closed
"""
- pass
+ m = Message()
+ m.add_byte(cMSG_CHANNEL_REQUEST)
+ m.add_int(self.remote_chanid)
+ m.add_string("env")
+ m.add_boolean(False)
+ m.add_string(name)
+ m.add_string(value)
+ self.transport._send_user_message(m)
def exit_status_ready(self):
"""
@@ -260,7 +372,7 @@ class Channel(ClosingContextManager):
.. versionadded:: 1.7.3
"""
- pass
+ return self.closed or self.status_event.is_set()
def recv_exit_status(self):
"""
@@ -285,7 +397,9 @@ class Channel(ClosingContextManager):
.. versionadded:: 1.2
"""
- pass
+ self.status_event.wait()
+ assert self.status_event.is_set()
+ return self.exit_status
def send_exit_status(self, status):
"""
@@ -298,11 +412,25 @@ class Channel(ClosingContextManager):
.. versionadded:: 1.2
"""
- pass
+ # in many cases, the channel will not still be open here.
+ # that's fine.
+ m = Message()
+ m.add_byte(cMSG_CHANNEL_REQUEST)
+ m.add_int(self.remote_chanid)
+ m.add_string("exit-status")
+ m.add_boolean(False)
+ m.add_int(status)
+ self.transport._send_user_message(m)
@open_only
- def request_x11(self, screen_number=0, auth_protocol=None, auth_cookie=
- None, single_connection=False, handler=None):
+ def request_x11(
+ self,
+ screen_number=0,
+ auth_protocol=None,
+ auth_cookie=None,
+ single_connection=False,
+ handler=None,
+ ):
"""
Request an x11 session on this channel. If the server allows it,
further x11 requests can be made from the server to the client,
@@ -341,7 +469,25 @@ class Channel(ClosingContextManager):
an optional callable handler to use for incoming X11 connections
:return: the auth_cookie used
"""
- pass
+ if auth_protocol is None:
+ auth_protocol = "MIT-MAGIC-COOKIE-1"
+ if auth_cookie is None:
+ auth_cookie = binascii.hexlify(os.urandom(16))
+
+ m = Message()
+ m.add_byte(cMSG_CHANNEL_REQUEST)
+ m.add_int(self.remote_chanid)
+ m.add_string("x11-req")
+ m.add_boolean(True)
+ m.add_boolean(single_connection)
+ m.add_string(auth_protocol)
+ m.add_string(auth_cookie)
+ m.add_int(screen_number)
+ self._event_pending()
+ self.transport._send_user_message(m)
+ self._wait_for_event()
+ self.transport._set_x11_handler(handler)
+ return auth_cookie
@open_only
def request_forward_agent(self, handler):
@@ -358,13 +504,20 @@ class Channel(ClosingContextManager):
:raises: SSHException in case of channel problem.
"""
- pass
+ m = Message()
+ m.add_byte(cMSG_CHANNEL_REQUEST)
+ m.add_int(self.remote_chanid)
+ m.add_string("auth-agent-req@openssh.com")
+ m.add_boolean(False)
+ self.transport._send_user_message(m)
+ self.transport._set_forward_agent_handler(handler)
+ return True
def get_transport(self):
"""
Return the `.Transport` associated with this channel.
"""
- pass
+ return self.transport
def set_name(self, name):
"""
@@ -374,13 +527,13 @@ class Channel(ClosingContextManager):
:param str name: new channel name
"""
- pass
+ self._name = name
def get_name(self):
"""
Get the name of this channel that was previously set by `set_name`.
"""
- pass
+ return self._name
def get_id(self):
"""
@@ -391,7 +544,7 @@ class Channel(ClosingContextManager):
`.ServerInterface.check_channel_request` when determining whether to
accept a channel request in server mode.
"""
- pass
+ return self.chanid
def set_combine_stderr(self, combine):
"""
@@ -414,7 +567,21 @@ class Channel(ClosingContextManager):
.. versionadded:: 1.1
"""
- pass
+ data = bytes()
+ self.lock.acquire()
+ try:
+ old = self.combine_stderr
+ self.combine_stderr = combine
+ if combine and not old:
+ # copy old stderr buffer into primary buffer
+ data = self.in_stderr_buffer.empty()
+ finally:
+ self.lock.release()
+ if len(data) > 0:
+ self._feed(data)
+ return old
+
+ # ...socket API...
def settimeout(self, timeout):
"""
@@ -432,7 +599,7 @@ class Channel(ClosingContextManager):
seconds to wait for a pending read/write operation before raising
``socket.timeout``, or ``None`` for no timeout.
"""
- pass
+ self.timeout = timeout
def gettimeout(self):
"""
@@ -440,7 +607,7 @@ class Channel(ClosingContextManager):
operations, or ``None`` if no timeout is set. This reflects the last
call to `setblocking` or `settimeout`.
"""
- pass
+ return self.timeout
def setblocking(self, blocking):
"""
@@ -460,7 +627,10 @@ class Channel(ClosingContextManager):
:param int blocking:
0 to set non-blocking mode; non-0 to set blocking mode.
"""
- pass
+ if blocking:
+ self.settimeout(None)
+ else:
+ self.settimeout(0.0)
def getpeername(self):
"""
@@ -470,7 +640,7 @@ class Channel(ClosingContextManager):
socket-like interface to allow asyncore to work. (asyncore likes to
call ``'getpeername'``.)
"""
- pass
+ return self.transport.getpeername()
def close(self):
"""
@@ -479,7 +649,24 @@ class Channel(ClosingContextManager):
is flushed). Channels are automatically closed when their `.Transport`
is closed or when they are garbage collected.
"""
- pass
+ self.lock.acquire()
+ try:
+ # only close the pipe when the user explicitly closes the channel.
+ # otherwise they will get unpleasant surprises. (and do it before
+ # checking self.closed, since the remote host may have already
+ # closed the connection.)
+ if self._pipe is not None:
+ self._pipe.close()
+ self._pipe = None
+
+ if not self.active or self.closed:
+ return
+ msgs = self._close_internal()
+ finally:
+ self.lock.release()
+ for m in msgs:
+ if m is not None:
+ self.transport._send_user_message(m)
def recv_ready(self):
"""
@@ -491,7 +678,7 @@ class Channel(ClosingContextManager):
``True`` if a `recv` call on this channel would immediately return
at least one byte; ``False`` otherwise.
"""
- pass
+ return self.in_buffer.read_ready()
def recv(self, nbytes):
"""
@@ -506,7 +693,21 @@ class Channel(ClosingContextManager):
:raises socket.timeout:
if no data is ready before the timeout set by `settimeout`.
"""
- pass
+ try:
+ out = self.in_buffer.read(nbytes, self.timeout)
+ except PipeTimeout:
+ raise socket.timeout()
+
+ ack = self._check_add_window(len(out))
+ # no need to hold the channel lock when sending this
+ if ack > 0:
+ m = Message()
+ m.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
+ m.add_int(self.remote_chanid)
+ m.add_int(ack)
+ self.transport._send_user_message(m)
+
+ return out
def recv_stderr_ready(self):
"""
@@ -521,7 +722,7 @@ class Channel(ClosingContextManager):
.. versionadded:: 1.1
"""
- pass
+ return self.in_stderr_buffer.read_ready()
def recv_stderr(self, nbytes):
"""
@@ -540,7 +741,21 @@ class Channel(ClosingContextManager):
.. versionadded:: 1.1
"""
- pass
+ try:
+ out = self.in_stderr_buffer.read(nbytes, self.timeout)
+ except PipeTimeout:
+ raise socket.timeout()
+
+ ack = self._check_add_window(len(out))
+ # no need to hold the channel lock when sending this
+ if ack > 0:
+ m = Message()
+ m.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
+ m.add_int(self.remote_chanid)
+ m.add_int(ack)
+ self.transport._send_user_message(m)
+
+ return out
def send_ready(self):
"""
@@ -555,7 +770,13 @@ class Channel(ClosingContextManager):
``True`` if a `send` call on this channel would immediately succeed
or fail
"""
- pass
+ self.lock.acquire()
+ try:
+ if self.closed or self.eof_sent:
+ return True
+ return self.out_window_size > 0
+ finally:
+ self.lock.release()
def send(self, s):
"""
@@ -571,7 +792,11 @@ class Channel(ClosingContextManager):
:raises socket.timeout: if no data could be sent before the timeout set
by `settimeout`.
"""
- pass
+
+ m = Message()
+ m.add_byte(cMSG_CHANNEL_DATA)
+ m.add_int(self.remote_chanid)
+ return self._send(s, m)
def send_stderr(self, s):
"""
@@ -590,7 +815,12 @@ class Channel(ClosingContextManager):
.. versionadded:: 1.1
"""
- pass
+
+ m = Message()
+ m.add_byte(cMSG_CHANNEL_EXTENDED_DATA)
+ m.add_int(self.remote_chanid)
+ m.add_int(1)
+ return self._send(s, m)
def sendall(self, s):
"""
@@ -610,7 +840,10 @@ class Channel(ClosingContextManager):
sent, there is no way to determine how much data (if any) was sent.
This is irritating, but identically follows Python's API.
"""
- pass
+ while s:
+ sent = self.send(s)
+ s = s[sent:]
+ return None
def sendall_stderr(self, s):
"""
@@ -628,7 +861,10 @@ class Channel(ClosingContextManager):
.. versionadded:: 1.1
"""
- pass
+ while s:
+ sent = self.send_stderr(s)
+ s = s[sent:]
+ return None
def makefile(self, *params):
"""
@@ -638,7 +874,7 @@ class Channel(ClosingContextManager):
:return: `.ChannelFile` object which can be used for Python file I/O.
"""
- pass
+ return ChannelFile(*([self] + list(params)))
def makefile_stderr(self, *params):
"""
@@ -656,7 +892,7 @@ class Channel(ClosingContextManager):
.. versionadded:: 1.1
"""
- pass
+ return ChannelStderrFile(*([self] + list(params)))
def makefile_stdin(self, *params):
"""
@@ -673,7 +909,7 @@ class Channel(ClosingContextManager):
.. versionadded:: 2.6
"""
- pass
+ return ChannelStdinFile(*([self] + list(params)))
def fileno(self):
"""
@@ -692,7 +928,18 @@ class Channel(ClosingContextManager):
.. warning::
This method causes channel reads to be slightly less efficient.
"""
- pass
+ self.lock.acquire()
+ try:
+ if self._pipe is not None:
+ return self._pipe.fileno()
+ # create the pipe and feed in any existing data
+ self._pipe = pipe.make_pipe()
+ p1, p2 = pipe.make_or_pipe(self._pipe)
+ self.in_buffer.set_event(p1)
+ self.in_stderr_buffer.set_event(p2)
+ return self._pipe.fileno()
+ finally:
+ self.lock.release()
def shutdown(self, how):
"""
@@ -705,7 +952,17 @@ class Channel(ClosingContextManager):
0 (stop receiving), 1 (stop sending), or 2 (stop receiving and
sending).
"""
- pass
+ if (how == 0) or (how == 2):
+ # feign "read" shutdown
+ self.eof_received = 1
+ if (how == 1) or (how == 2):
+ self.lock.acquire()
+ try:
+ m = self._send_eof()
+ finally:
+ self.lock.release()
+ if m is not None:
+ self.transport._send_user_message(m)
def shutdown_read(self):
"""
@@ -717,7 +974,7 @@ class Channel(ClosingContextManager):
.. versionadded:: 1.2
"""
- pass
+ self.shutdown(0)
def shutdown_write(self):
"""
@@ -729,7 +986,310 @@ class Channel(ClosingContextManager):
.. versionadded:: 1.2
"""
- pass
+ self.shutdown(1)
+
+ @property
+ def _closed(self):
+ # Concession to Python 3's socket API, which has a private ._closed
+ # attribute instead of a semipublic .closed attribute.
+ return self.closed
+
+ # ...calls from Transport
+
+ def _set_transport(self, transport):
+ self.transport = transport
+ self.logger = util.get_logger(self.transport.get_log_channel())
+
+ def _set_window(self, window_size, max_packet_size):
+ self.in_window_size = window_size
+ self.in_max_packet_size = max_packet_size
+ # threshold of bytes we receive before we bother to send
+ # a window update
+ self.in_window_threshold = window_size // 10
+ self.in_window_sofar = 0
+ self._log(DEBUG, "Max packet in: {} bytes".format(max_packet_size))
+
+ def _set_remote_channel(self, chanid, window_size, max_packet_size):
+ self.remote_chanid = chanid
+ self.out_window_size = window_size
+ self.out_max_packet_size = self.transport._sanitize_packet_size(
+ max_packet_size
+ )
+ self.active = 1
+ self._log(
+ DEBUG, "Max packet out: {} bytes".format(self.out_max_packet_size)
+ )
+
+ def _request_success(self, m):
+ self._log(DEBUG, "Sesch channel {} request ok".format(self.chanid))
+ self.event_ready = True
+ self.event.set()
+ return
+
+ def _request_failed(self, m):
+ self.lock.acquire()
+ try:
+ msgs = self._close_internal()
+ finally:
+ self.lock.release()
+ for m in msgs:
+ if m is not None:
+ self.transport._send_user_message(m)
+
+ def _feed(self, m):
+ if isinstance(m, bytes):
+ # passed from _feed_extended
+ s = m
+ else:
+ s = m.get_binary()
+ self.in_buffer.feed(s)
+
+ def _feed_extended(self, m):
+ code = m.get_int()
+ s = m.get_binary()
+ if code != 1:
+ self._log(
+ ERROR, "unknown extended_data type {}; discarding".format(code)
+ )
+ return
+ if self.combine_stderr:
+ self._feed(s)
+ else:
+ self.in_stderr_buffer.feed(s)
+
+ def _window_adjust(self, m):
+ nbytes = m.get_int()
+ self.lock.acquire()
+ try:
+ if self.ultra_debug:
+ self._log(DEBUG, "window up {}".format(nbytes))
+ self.out_window_size += nbytes
+ self.out_buffer_cv.notify_all()
+ finally:
+ self.lock.release()
+
+ def _handle_request(self, m):
+ key = m.get_text()
+ want_reply = m.get_boolean()
+ server = self.transport.server_object
+ ok = False
+ if key == "exit-status":
+ self.exit_status = m.get_int()
+ self.status_event.set()
+ ok = True
+ elif key == "xon-xoff":
+ # ignore
+ ok = True
+ elif key == "pty-req":
+ term = m.get_string()
+ width = m.get_int()
+ height = m.get_int()
+ pixelwidth = m.get_int()
+ pixelheight = m.get_int()
+ modes = m.get_string()
+ if server is None:
+ ok = False
+ else:
+ ok = server.check_channel_pty_request(
+ self, term, width, height, pixelwidth, pixelheight, modes
+ )
+ elif key == "shell":
+ if server is None:
+ ok = False
+ else:
+ ok = server.check_channel_shell_request(self)
+ elif key == "env":
+ name = m.get_string()
+ value = m.get_string()
+ if server is None:
+ ok = False
+ else:
+ ok = server.check_channel_env_request(self, name, value)
+ elif key == "exec":
+ cmd = m.get_string()
+ if server is None:
+ ok = False
+ else:
+ ok = server.check_channel_exec_request(self, cmd)
+ elif key == "subsystem":
+ name = m.get_text()
+ if server is None:
+ ok = False
+ else:
+ ok = server.check_channel_subsystem_request(self, name)
+ elif key == "window-change":
+ width = m.get_int()
+ height = m.get_int()
+ pixelwidth = m.get_int()
+ pixelheight = m.get_int()
+ if server is None:
+ ok = False
+ else:
+ ok = server.check_channel_window_change_request(
+ self, width, height, pixelwidth, pixelheight
+ )
+ elif key == "x11-req":
+ single_connection = m.get_boolean()
+ auth_proto = m.get_text()
+ auth_cookie = m.get_binary()
+ screen_number = m.get_int()
+ if server is None:
+ ok = False
+ else:
+ ok = server.check_channel_x11_request(
+ self,
+ single_connection,
+ auth_proto,
+ auth_cookie,
+ screen_number,
+ )
+ elif key == "auth-agent-req@openssh.com":
+ if server is None:
+ ok = False
+ else:
+ ok = server.check_channel_forward_agent_request(self)
+ else:
+ self._log(DEBUG, 'Unhandled channel request "{}"'.format(key))
+ ok = False
+ if want_reply:
+ m = Message()
+ if ok:
+ m.add_byte(cMSG_CHANNEL_SUCCESS)
+ else:
+ m.add_byte(cMSG_CHANNEL_FAILURE)
+ m.add_int(self.remote_chanid)
+ self.transport._send_user_message(m)
+
+ def _handle_eof(self, m):
+ self.lock.acquire()
+ try:
+ if not self.eof_received:
+ self.eof_received = True
+ self.in_buffer.close()
+ self.in_stderr_buffer.close()
+ if self._pipe is not None:
+ self._pipe.set_forever()
+ finally:
+ self.lock.release()
+ self._log(DEBUG, "EOF received ({})".format(self._name))
+
+ def _handle_close(self, m):
+ self.lock.acquire()
+ try:
+ msgs = self._close_internal()
+ self.transport._unlink_channel(self.chanid)
+ finally:
+ self.lock.release()
+ for m in msgs:
+ if m is not None:
+ self.transport._send_user_message(m)
+
+ # ...internals...
+
+ def _send(self, s, m):
+ size = len(s)
+ self.lock.acquire()
+ try:
+ if self.closed:
+ # this doesn't seem useful, but it is the documented behavior
+ # of Socket
+ raise socket.error("Socket is closed")
+ size = self._wait_for_send_window(size)
+ if size == 0:
+ # eof or similar
+ return 0
+ m.add_string(s[:size])
+ finally:
+ self.lock.release()
+ # Note: We release self.lock before calling _send_user_message.
+ # Otherwise, we can deadlock during re-keying.
+ self.transport._send_user_message(m)
+ return size
+
+ def _log(self, level, msg, *args):
+ self.logger.log(level, "[chan " + self._name + "] " + msg, *args)
+
+ def _event_pending(self):
+ self.event.clear()
+ self.event_ready = False
+
+ def _wait_for_event(self):
+ self.event.wait()
+ assert self.event.is_set()
+ if self.event_ready:
+ return
+ e = self.transport.get_exception()
+ if e is None:
+ e = SSHException("Channel closed.")
+ raise e
+
+ def _set_closed(self):
+ # you are holding the lock.
+ self.closed = True
+ self.in_buffer.close()
+ self.in_stderr_buffer.close()
+ self.out_buffer_cv.notify_all()
+ # Notify any waiters that we are closed
+ self.event.set()
+ self.status_event.set()
+ if self._pipe is not None:
+ self._pipe.set_forever()
+
+ def _send_eof(self):
+ # you are holding the lock.
+ if self.eof_sent:
+ return None
+ m = Message()
+ m.add_byte(cMSG_CHANNEL_EOF)
+ m.add_int(self.remote_chanid)
+ self.eof_sent = True
+ self._log(DEBUG, "EOF sent ({})".format(self._name))
+ return m
+
+ def _close_internal(self):
+ # you are holding the lock.
+ if not self.active or self.closed:
+ return None, None
+ m1 = self._send_eof()
+ m2 = Message()
+ m2.add_byte(cMSG_CHANNEL_CLOSE)
+ m2.add_int(self.remote_chanid)
+ self._set_closed()
+ # can't unlink from the Transport yet -- the remote side may still
+ # try to send meta-data (exit-status, etc)
+ return m1, m2
+
+ def _unlink(self):
+ # server connection could die before we become active:
+ # still signal the close!
+ if self.closed:
+ return
+ self.lock.acquire()
+ try:
+ self._set_closed()
+ self.transport._unlink_channel(self.chanid)
+ finally:
+ self.lock.release()
+
+ def _check_add_window(self, n):
+ self.lock.acquire()
+ try:
+ if self.closed or self.eof_received or not self.active:
+ return 0
+ if self.ultra_debug:
+ self._log(DEBUG, "addwindow {}".format(n))
+ self.in_window_sofar += n
+ if self.in_window_sofar <= self.in_window_threshold:
+ return 0
+ if self.ultra_debug:
+ self._log(
+ DEBUG, "addwindow send {}".format(self.in_window_sofar)
+ )
+ out = self.in_window_sofar
+ self.in_window_sofar = 0
+ return out
+ finally:
+ self.lock.release()
def _wait_for_send_window(self, size):
"""
@@ -739,7 +1299,36 @@ class Channel(ClosingContextManager):
exception is raised. Returns the number of bytes available to send
(may be less than requested).
"""
- pass
+ # you are already holding the lock
+ if self.closed or self.eof_sent:
+ return 0
+ if self.out_window_size == 0:
+ # should we block?
+ if self.timeout == 0.0:
+ raise socket.timeout()
+ # loop here in case we get woken up but a different thread has
+ # filled the buffer
+ timeout = self.timeout
+ while self.out_window_size == 0:
+ if self.closed or self.eof_sent:
+ return 0
+ then = time.time()
+ self.out_buffer_cv.wait(timeout)
+ if timeout is not None:
+ timeout -= time.time() - then
+ if timeout <= 0.0:
+ raise socket.timeout()
+ # we have some window to squeeze into
+ if self.closed or self.eof_sent:
+ return 0
+ if self.out_window_size < size:
+ size = self.out_window_size
+ if self.out_max_packet_size - 64 < size:
+ size = self.out_max_packet_size - 64
+ self.out_window_size -= size
+ if self.ultra_debug:
+ self._log(DEBUG, "window down to {}".format(self.out_window_size))
+ return size
class ChannelFile(BufferedFile):
@@ -755,7 +1344,7 @@ class ChannelFile(BufferedFile):
flush the buffer.
"""
- def __init__(self, channel, mode='r', bufsize=-1):
+ def __init__(self, channel, mode="r", bufsize=-1):
self.channel = channel
BufferedFile.__init__(self)
self._set_mode(mode, bufsize)
@@ -764,7 +1353,14 @@ class ChannelFile(BufferedFile):
"""
Returns a string representation of this object, for debugging.
"""
- return '<paramiko.ChannelFile from ' + repr(self.channel) + '>'
+ return "<paramiko.ChannelFile from " + repr(self.channel) + ">"
+
+ def _read(self, size):
+ return self.channel.recv(size)
+
+ def _write(self, data):
+ self.channel.sendall(data)
+ return len(data)
class ChannelStderrFile(ChannelFile):
@@ -774,6 +1370,13 @@ class ChannelStderrFile(ChannelFile):
See `Channel.makefile_stderr` for details.
"""
+ def _read(self, size):
+ return self.channel.recv_stderr(size)
+
+ def _write(self, data):
+ self.channel.sendall_stderr(data)
+ return len(data)
+
class ChannelStdinFile(ChannelFile):
"""
@@ -781,3 +1384,7 @@ class ChannelStdinFile(ChannelFile):
See `Channel.makefile_stdin` for details.
"""
+
+ def close(self):
+ super().close()
+ self.channel.shutdown_write()
diff --git a/paramiko/client.py b/paramiko/client.py
index a04a5244..d8be9108 100644
--- a/paramiko/client.py
+++ b/paramiko/client.py
@@ -1,6 +1,25 @@
+# Copyright (C) 2006-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
SSH client & key policies
"""
+
from binascii import hexlify
import getpass
import inspect
@@ -8,6 +27,7 @@ import os
import socket
import warnings
from errno import ECONNREFUSED, EHOSTUNREACH
+
from paramiko.agent import Agent
from paramiko.common import DEBUG
from paramiko.config import SSH_PORT
@@ -16,7 +36,11 @@ from paramiko.ecdsakey import ECDSAKey
from paramiko.ed25519key import Ed25519Key
from paramiko.hostkeys import HostKeys
from paramiko.rsakey import RSAKey
-from paramiko.ssh_exception import SSHException, BadHostKeyException, NoValidConnectionsError
+from paramiko.ssh_exception import (
+ SSHException,
+ BadHostKeyException,
+ NoValidConnectionsError,
+)
from paramiko.transport import Transport
from paramiko.util import ClosingContextManager
@@ -72,7 +96,15 @@ class SSHClient(ClosingContextManager):
:raises: ``IOError`` --
if a filename was provided and the file could not be read
"""
- pass
+ if filename is None:
+ # try the user's .ssh key file, and mask exceptions
+ filename = os.path.expanduser("~/.ssh/known_hosts")
+ try:
+ self._system_host_keys.load(filename)
+ except IOError:
+ pass
+ return
+ self._system_host_keys.load(filename)
def load_host_keys(self, filename):
"""
@@ -90,7 +122,8 @@ class SSHClient(ClosingContextManager):
:raises: ``IOError`` -- if the filename could not be read
"""
- pass
+ self._host_keys_filename = filename
+ self._host_keys.load(filename)
def save_host_keys(self, filename):
"""
@@ -102,7 +135,20 @@ class SSHClient(ClosingContextManager):
:raises: ``IOError`` -- if the file could not be written
"""
- pass
+
+ # update local host keys from file (in case other SSH clients
+ # have written to the known_hosts file meanwhile).
+ if self._host_keys_filename is not None:
+ self.load_host_keys(self._host_keys_filename)
+
+ with open(filename, "w") as f:
+ for hostname, keys in self._host_keys.items():
+ for keytype, key in keys.items():
+ f.write(
+ "{} {} {}\n".format(
+ hostname, keytype, key.get_base64()
+ )
+ )
def get_host_keys(self):
"""
@@ -111,7 +157,7 @@ class SSHClient(ClosingContextManager):
:return: the local host keys as a `.HostKeys` object.
"""
- pass
+ return self._host_keys
def set_log_channel(self, name):
"""
@@ -120,7 +166,7 @@ class SSHClient(ClosingContextManager):
:param str name: new channel name for logging
"""
- pass
+ self._log_channel = name
def set_missing_host_key_policy(self, policy):
"""
@@ -140,7 +186,9 @@ class SSHClient(ClosingContextManager):
the policy to use when receiving a host key from a
previously-unknown server
"""
- pass
+ if inspect.isclass(policy):
+ policy = policy()
+ self._policy = policy
def _families_and_addresses(self, hostname, port):
"""
@@ -150,15 +198,48 @@ class SSHClient(ClosingContextManager):
:param int port: the server port to connect to
:returns: Yields an iterable of ``(family, address)`` tuples
"""
- pass
-
- def connect(self, hostname, port=SSH_PORT, username=None, password=None,
- pkey=None, key_filename=None, timeout=None, allow_agent=True,
- look_for_keys=True, compress=False, sock=None, gss_auth=False,
- gss_kex=False, gss_deleg_creds=True, gss_host=None, banner_timeout=
- None, auth_timeout=None, channel_timeout=None, gss_trust_dns=True,
- passphrase=None, disabled_algorithms=None, transport_factory=None,
- auth_strategy=None):
+ guess = True
+ addrinfos = socket.getaddrinfo(
+ hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM
+ )
+ for (family, socktype, proto, canonname, sockaddr) in addrinfos:
+ if socktype == socket.SOCK_STREAM:
+ yield family, sockaddr
+ guess = False
+
+ # some OSes like AIX don't indicate SOCK_STREAM support, so just
+ # guess. :( We only do this if we did not get a single result marked
+ # as socktype == SOCK_STREAM.
+ if guess:
+ for family, _, _, _, sockaddr in addrinfos:
+ yield family, sockaddr
+
+ def connect(
+ self,
+ hostname,
+ port=SSH_PORT,
+ username=None,
+ password=None,
+ pkey=None,
+ key_filename=None,
+ timeout=None,
+ allow_agent=True,
+ look_for_keys=True,
+ compress=False,
+ sock=None,
+ gss_auth=False,
+ gss_kex=False,
+ gss_deleg_creds=True,
+ gss_host=None,
+ banner_timeout=None,
+ auth_timeout=None,
+ channel_timeout=None,
+ gss_trust_dns=True,
+ passphrase=None,
+ disabled_algorithms=None,
+ transport_factory=None,
+ auth_strategy=None,
+ ):
"""
Connect to an SSH server and authenticate to it. The server's host key
is checked against the system host keys (see `load_system_host_keys`)
@@ -290,7 +371,130 @@ class SSHClient(ClosingContextManager):
.. versionchanged:: 3.2
Added the ``auth_strategy`` argument.
"""
- pass
+ if not sock:
+ errors = {}
+ # Try multiple possible address families (e.g. IPv4 vs IPv6)
+ to_try = list(self._families_and_addresses(hostname, port))
+ for af, addr in to_try:
+ try:
+ sock = socket.socket(af, socket.SOCK_STREAM)
+ if timeout is not None:
+ try:
+ sock.settimeout(timeout)
+ except:
+ pass
+ sock.connect(addr)
+ # Break out of the loop on success
+ break
+ except socket.error as e:
+ # As mentioned in socket docs it is better
+ # to close sockets explicitly
+ if sock:
+ sock.close()
+ # Raise anything that isn't a straight up connection error
+ # (such as a resolution error)
+ if e.errno not in (ECONNREFUSED, EHOSTUNREACH):
+ raise
+ # Capture anything else so we know how the run looks once
+ # iteration is complete. Retain info about which attempt
+ # this was.
+ errors[addr] = e
+
+ # Make sure we explode usefully if no address family attempts
+ # succeeded. We've no way of knowing which error is the "right"
+ # one, so we construct a hybrid exception containing all the real
+ # ones, of a subclass that client code should still be watching for
+ # (socket.error)
+ if len(errors) == len(to_try):
+ raise NoValidConnectionsError(errors)
+
+ if transport_factory is None:
+ transport_factory = Transport
+ t = self._transport = transport_factory(
+ sock,
+ gss_kex=gss_kex,
+ gss_deleg_creds=gss_deleg_creds,
+ disabled_algorithms=disabled_algorithms,
+ )
+ t.use_compression(compress=compress)
+ t.set_gss_host(
+ # t.hostname may be None, but GSS-API requires a target name.
+ # Therefore use hostname as fallback.
+ gss_host=gss_host or hostname,
+ trust_dns=gss_trust_dns,
+ gssapi_requested=gss_auth or gss_kex,
+ )
+ if self._log_channel is not None:
+ t.set_log_channel(self._log_channel)
+ if banner_timeout is not None:
+ t.banner_timeout = banner_timeout
+ if auth_timeout is not None:
+ t.auth_timeout = auth_timeout
+ if channel_timeout is not None:
+ t.channel_timeout = channel_timeout
+
+ if port == SSH_PORT:
+ server_hostkey_name = hostname
+ else:
+ server_hostkey_name = "[{}]:{}".format(hostname, port)
+ our_server_keys = None
+
+ our_server_keys = self._system_host_keys.get(server_hostkey_name)
+ if our_server_keys is None:
+ our_server_keys = self._host_keys.get(server_hostkey_name)
+ if our_server_keys is not None:
+ keytype = our_server_keys.keys()[0]
+ sec_opts = t.get_security_options()
+ other_types = [x for x in sec_opts.key_types if x != keytype]
+ sec_opts.key_types = [keytype] + other_types
+
+ t.start_client(timeout=timeout)
+
+ # If GSS-API Key Exchange is performed we are not required to check the
+ # host key, because the host is authenticated via GSS-API / SSPI as
+ # well as our client.
+ if not self._transport.gss_kex_used:
+ server_key = t.get_remote_server_key()
+ if our_server_keys is None:
+ # will raise exception if the key is rejected
+ self._policy.missing_host_key(
+ self, server_hostkey_name, server_key
+ )
+ else:
+ our_key = our_server_keys.get(server_key.get_name())
+ if our_key != server_key:
+ if our_key is None:
+ our_key = list(our_server_keys.values())[0]
+ raise BadHostKeyException(hostname, server_key, our_key)
+
+ if username is None:
+ username = getpass.getuser()
+
+ # New auth flow!
+ if auth_strategy is not None:
+ return auth_strategy.authenticate(transport=t)
+
+ # Old auth flow!
+ if key_filename is None:
+ key_filenames = []
+ elif isinstance(key_filename, str):
+ key_filenames = [key_filename]
+ else:
+ key_filenames = key_filename
+
+ self._auth(
+ username,
+ password,
+ pkey,
+ key_filenames,
+ allow_agent,
+ look_for_keys,
+ gss_auth,
+ gss_kex,
+ gss_deleg_creds,
+ t.gss_host,
+ passphrase,
+ )
def close(self):
"""
@@ -304,10 +508,23 @@ class SSHClient(ClosingContextManager):
reliable. Failure to explicitly close your client after use may
lead to end-of-process hangs!
"""
- pass
+ if self._transport is None:
+ return
+ self._transport.close()
+ self._transport = None
- def exec_command(self, command, bufsize=-1, timeout=None, get_pty=False,
- environment=None):
+ if self._agent is not None:
+ self._agent.close()
+ self._agent = None
+
+ def exec_command(
+ self,
+ command,
+ bufsize=-1,
+ timeout=None,
+ get_pty=False,
+ environment=None,
+ ):
"""
Execute a command on the SSH server. A new `.Channel` is opened and
the requested command is executed. The command's input and output
@@ -340,10 +557,27 @@ class SSHClient(ClosingContextManager):
.. versionchanged:: 1.10
Added the ``get_pty`` kwarg.
"""
- pass
-
- def invoke_shell(self, term='vt100', width=80, height=24, width_pixels=
- 0, height_pixels=0, environment=None):
+ chan = self._transport.open_session(timeout=timeout)
+ if get_pty:
+ chan.get_pty()
+ chan.settimeout(timeout)
+ if environment:
+ chan.update_environment(environment)
+ chan.exec_command(command)
+ stdin = chan.makefile_stdin("wb", bufsize)
+ stdout = chan.makefile("r", bufsize)
+ stderr = chan.makefile_stderr("r", bufsize)
+ return stdin, stdout, stderr
+
+ def invoke_shell(
+ self,
+ term="vt100",
+ width=80,
+ height=24,
+ width_pixels=0,
+ height_pixels=0,
+ environment=None,
+ ):
"""
Start an interactive shell session on the SSH server. A new `.Channel`
is opened and connected to a pseudo-terminal using the requested
@@ -360,7 +594,10 @@ class SSHClient(ClosingContextManager):
:raises: `.SSHException` -- if the server fails to invoke a shell
"""
- pass
+ chan = self._transport.open_session()
+ chan.get_pty(term, width, height, width_pixels, height_pixels)
+ chan.invoke_shell()
+ return chan
def open_sftp(self):
"""
@@ -368,7 +605,7 @@ class SSHClient(ClosingContextManager):
:return: a new `.SFTPClient` session object
"""
- pass
+ return self._transport.open_sftp_client()
def get_transport(self):
"""
@@ -378,7 +615,7 @@ class SSHClient(ClosingContextManager):
:return: the `.Transport` for this connection
"""
- pass
+ return self._transport
def _key_from_filepath(self, filename, klass, password):
"""
@@ -389,11 +626,43 @@ class SSHClient(ClosingContextManager):
- Otherwise, the filename is assumed to be a private key, and the
matching public cert will be loaded if it exists.
"""
- pass
-
- def _auth(self, username, password, pkey, key_filenames, allow_agent,
- look_for_keys, gss_auth, gss_kex, gss_deleg_creds, gss_host, passphrase
- ):
+ cert_suffix = "-cert.pub"
+ # Assume privkey, not cert, by default
+ if filename.endswith(cert_suffix):
+ key_path = filename[: -len(cert_suffix)]
+ cert_path = filename
+ else:
+ key_path = filename
+ cert_path = filename + cert_suffix
+ # Blindly try the key path; if no private key, nothing will work.
+ key = klass.from_private_key_file(key_path, password)
+ # TODO: change this to 'Loading' instead of 'Trying' sometime; probably
+ # when #387 is released, since this is a critical log message users are
+ # likely testing/filtering for (bah.)
+ msg = "Trying discovered key {} in {}".format(
+ hexlify(key.get_fingerprint()), key_path
+ )
+ self._log(DEBUG, msg)
+ # Attempt to load cert if it exists.
+ if os.path.isfile(cert_path):
+ key.load_certificate(cert_path)
+ self._log(DEBUG, "Adding public certificate {}".format(cert_path))
+ return key
+
+ def _auth(
+ self,
+ username,
+ password,
+ pkey,
+ key_filenames,
+ allow_agent,
+ look_for_keys,
+ gss_auth,
+ gss_kex,
+ gss_deleg_creds,
+ gss_host,
+ passphrase,
+ ):
"""
Try, in order:
@@ -407,7 +676,150 @@ class SSHClient(ClosingContextManager):
isn't also given], or for two-factor authentication [for which it is
required].)
"""
- pass
+ saved_exception = None
+ two_factor = False
+ allowed_types = set()
+ two_factor_types = {"keyboard-interactive", "password"}
+ if passphrase is None and password is not None:
+ passphrase = password
+
+ # If GSS-API support and GSS-API Key Exchange was performed, we attempt
+ # authentication with gssapi-keyex.
+ if gss_kex and self._transport.gss_kex_used:
+ try:
+ self._transport.auth_gssapi_keyex(username)
+ return
+ except Exception as e:
+ saved_exception = e
+
+ # Try GSS-API authentication (gssapi-with-mic) only if GSS-API Key
+ # Exchange is not performed, because if we use GSS-API for the key
+ # exchange, there is already a fully established GSS-API context, so
+ # why should we do that again?
+ if gss_auth:
+ try:
+ return self._transport.auth_gssapi_with_mic(
+ username, gss_host, gss_deleg_creds
+ )
+ except Exception as e:
+ saved_exception = e
+
+ if pkey is not None:
+ try:
+ self._log(
+ DEBUG,
+ "Trying SSH key {}".format(
+ hexlify(pkey.get_fingerprint())
+ ),
+ )
+ allowed_types = set(
+ self._transport.auth_publickey(username, pkey)
+ )
+ two_factor = allowed_types & two_factor_types
+ if not two_factor:
+ return
+ except SSHException as e:
+ saved_exception = e
+
+ if not two_factor:
+ for key_filename in key_filenames:
+ # TODO 4.0: leverage PKey.from_path() if we don't end up just
+ # killing SSHClient entirely
+ for pkey_class in (RSAKey, DSSKey, ECDSAKey, Ed25519Key):
+ try:
+ key = self._key_from_filepath(
+ key_filename, pkey_class, passphrase
+ )
+ allowed_types = set(
+ self._transport.auth_publickey(username, key)
+ )
+ two_factor = allowed_types & two_factor_types
+ if not two_factor:
+ return
+ break
+ except SSHException as e:
+ saved_exception = e
+
+ if not two_factor and allow_agent:
+ if self._agent is None:
+ self._agent = Agent()
+
+ for key in self._agent.get_keys():
+ try:
+ id_ = hexlify(key.get_fingerprint())
+ self._log(DEBUG, "Trying SSH agent key {}".format(id_))
+ # for 2-factor auth, a successful key auth returns the auth
+ # types still allowed (e.g. password) as the second factor
+ allowed_types = set(
+ self._transport.auth_publickey(username, key)
+ )
+ two_factor = allowed_types & two_factor_types
+ if not two_factor:
+ return
+ break
+ except SSHException as e:
+ saved_exception = e
+
+ if not two_factor:
+ keyfiles = []
+
+ for keytype, name in [
+ (RSAKey, "rsa"),
+ (DSSKey, "dsa"),
+ (ECDSAKey, "ecdsa"),
+ (Ed25519Key, "ed25519"),
+ ]:
+ # ~/ssh/ is for Windows
+ for directory in [".ssh", "ssh"]:
+ full_path = os.path.expanduser(
+ "~/{}/id_{}".format(directory, name)
+ )
+ if os.path.isfile(full_path):
+ # TODO: only do this append if below did not run
+ keyfiles.append((keytype, full_path))
+ if os.path.isfile(full_path + "-cert.pub"):
+ keyfiles.append((keytype, full_path + "-cert.pub"))
+
+ if not look_for_keys:
+ keyfiles = []
+
+ for pkey_class, filename in keyfiles:
+ try:
+ key = self._key_from_filepath(
+ filename, pkey_class, passphrase
+ )
+ # for 2-factor auth a successfully auth'd key will result
+ # in ['password']
+ allowed_types = set(
+ self._transport.auth_publickey(username, key)
+ )
+ two_factor = allowed_types & two_factor_types
+ if not two_factor:
+ return
+ break
+ except (SSHException, IOError) as e:
+ saved_exception = e
+
+ if password is not None:
+ try:
+ self._transport.auth_password(username, password)
+ return
+ except SSHException as e:
+ saved_exception = e
+ elif two_factor:
+ try:
+ self._transport.auth_interactive_dumb(username)
+ return
+ except SSHException as e:
+ saved_exception = e
+
+ # if we got an auth-failed exception earlier, re-raise it
+ if saved_exception is not None:
+ raise saved_exception
+ raise SSHException("No authentication methods available")
+
+ def _log(self, level, msg):
+ self._transport._log(level, msg)
class MissingHostKeyPolicy:
@@ -437,6 +849,17 @@ class AutoAddPolicy(MissingHostKeyPolicy):
local `.HostKeys` object, and saving it. This is used by `.SSHClient`.
"""
+ def missing_host_key(self, client, hostname, key):
+ client._host_keys.add(hostname, key.get_name(), key)
+ if client._host_keys_filename is not None:
+ client.save_host_keys(client._host_keys_filename)
+ client._log(
+ DEBUG,
+ "Adding {} host key for {}: {}".format(
+ key.get_name(), hostname, hexlify(key.get_fingerprint())
+ ),
+ )
+
class RejectPolicy(MissingHostKeyPolicy):
"""
@@ -444,9 +867,27 @@ class RejectPolicy(MissingHostKeyPolicy):
used by `.SSHClient`.
"""
+ def missing_host_key(self, client, hostname, key):
+ client._log(
+ DEBUG,
+ "Rejecting {} host key for {}: {}".format(
+ key.get_name(), hostname, hexlify(key.get_fingerprint())
+ ),
+ )
+ raise SSHException(
+ "Server {!r} not found in known_hosts".format(hostname)
+ )
+
class WarningPolicy(MissingHostKeyPolicy):
"""
Policy for logging a Python-style warning for an unknown host key, but
accepting it. This is used by `.SSHClient`.
"""
+
+ def missing_host_key(self, client, hostname, key):
+ warnings.warn(
+ "Unknown {} host key for {}: {}".format(
+ key.get_name(), hostname, hexlify(key.get_fingerprint())
+ )
+ )
diff --git a/paramiko/common.py b/paramiko/common.py
index 29e86a9d..b57149b7 100644
--- a/paramiko/common.py
+++ b/paramiko/common.py
@@ -1,24 +1,90 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Common constants and global variables.
"""
import logging
import struct
-(MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG,
- MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT, MSG_EXT_INFO) = range(1, 8)
-MSG_KEXINIT, MSG_NEWKEYS = range(20, 22)
-(MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE, MSG_USERAUTH_SUCCESS,
- MSG_USERAUTH_BANNER) = range(50, 54)
+
+#
+# Formerly of py3compat.py. May be fully delete'able with a deeper look?
+#
+
+
+def byte_chr(c):
+ assert isinstance(c, int)
+ return struct.pack("B", c)
+
+
+def byte_mask(c, mask):
+ assert isinstance(c, int)
+ return struct.pack("B", c & mask)
+
+
+def byte_ord(c):
+ # In case we're handed a string instead of an int.
+ if not isinstance(c, int):
+ c = ord(c)
+ return c
+
+
+(
+ MSG_DISCONNECT,
+ MSG_IGNORE,
+ MSG_UNIMPLEMENTED,
+ MSG_DEBUG,
+ MSG_SERVICE_REQUEST,
+ MSG_SERVICE_ACCEPT,
+ MSG_EXT_INFO,
+) = range(1, 8)
+(MSG_KEXINIT, MSG_NEWKEYS) = range(20, 22)
+(
+ MSG_USERAUTH_REQUEST,
+ MSG_USERAUTH_FAILURE,
+ MSG_USERAUTH_SUCCESS,
+ MSG_USERAUTH_BANNER,
+) = range(50, 54)
MSG_USERAUTH_PK_OK = 60
-MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE = range(60, 62)
-MSG_USERAUTH_GSSAPI_RESPONSE, MSG_USERAUTH_GSSAPI_TOKEN = range(60, 62)
-(MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE, MSG_USERAUTH_GSSAPI_ERROR,
- MSG_USERAUTH_GSSAPI_ERRTOK, MSG_USERAUTH_GSSAPI_MIC) = range(63, 67)
+(MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE) = range(60, 62)
+(MSG_USERAUTH_GSSAPI_RESPONSE, MSG_USERAUTH_GSSAPI_TOKEN) = range(60, 62)
+(
+ MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE,
+ MSG_USERAUTH_GSSAPI_ERROR,
+ MSG_USERAUTH_GSSAPI_ERRTOK,
+ MSG_USERAUTH_GSSAPI_MIC,
+) = range(63, 67)
HIGHEST_USERAUTH_MESSAGE_ID = 79
-MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE = range(80, 83)
-(MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE,
- MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA,
- MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST,
- MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE) = range(90, 101)
+(MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE) = range(80, 83)
+(
+ MSG_CHANNEL_OPEN,
+ MSG_CHANNEL_OPEN_SUCCESS,
+ MSG_CHANNEL_OPEN_FAILURE,
+ MSG_CHANNEL_WINDOW_ADJUST,
+ MSG_CHANNEL_DATA,
+ MSG_CHANNEL_EXTENDED_DATA,
+ MSG_CHANNEL_EOF,
+ MSG_CHANNEL_CLOSE,
+ MSG_CHANNEL_REQUEST,
+ MSG_CHANNEL_SUCCESS,
+ MSG_CHANNEL_FAILURE,
+) = range(90, 101)
+
cMSG_DISCONNECT = byte_chr(MSG_DISCONNECT)
cMSG_IGNORE = byte_chr(MSG_IGNORE)
cMSG_UNIMPLEMENTED = byte_chr(MSG_UNIMPLEMENTED)
@@ -38,7 +104,8 @@ cMSG_USERAUTH_INFO_RESPONSE = byte_chr(MSG_USERAUTH_INFO_RESPONSE)
cMSG_USERAUTH_GSSAPI_RESPONSE = byte_chr(MSG_USERAUTH_GSSAPI_RESPONSE)
cMSG_USERAUTH_GSSAPI_TOKEN = byte_chr(MSG_USERAUTH_GSSAPI_TOKEN)
cMSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE = byte_chr(
- MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE)
+ MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE
+)
cMSG_USERAUTH_GSSAPI_ERROR = byte_chr(MSG_USERAUTH_GSSAPI_ERROR)
cMSG_USERAUTH_GSSAPI_ERRTOK = byte_chr(MSG_USERAUTH_GSSAPI_ERRTOK)
cMSG_USERAUTH_GSSAPI_MIC = byte_chr(MSG_USERAUTH_GSSAPI_MIC)
@@ -56,52 +123,95 @@ cMSG_CHANNEL_CLOSE = byte_chr(MSG_CHANNEL_CLOSE)
cMSG_CHANNEL_REQUEST = byte_chr(MSG_CHANNEL_REQUEST)
cMSG_CHANNEL_SUCCESS = byte_chr(MSG_CHANNEL_SUCCESS)
cMSG_CHANNEL_FAILURE = byte_chr(MSG_CHANNEL_FAILURE)
-MSG_NAMES = {MSG_DISCONNECT: 'disconnect', MSG_IGNORE: 'ignore',
- MSG_UNIMPLEMENTED: 'unimplemented', MSG_DEBUG: 'debug',
- MSG_SERVICE_REQUEST: 'service-request', MSG_SERVICE_ACCEPT:
- 'service-accept', MSG_KEXINIT: 'kexinit', MSG_EXT_INFO: 'ext-info',
- MSG_NEWKEYS: 'newkeys', (30): 'kex30', (31): 'kex31', (32): 'kex32', (
- 33): 'kex33', (34): 'kex34', (40): 'kex40', (41): 'kex41',
- MSG_USERAUTH_REQUEST: 'userauth-request', MSG_USERAUTH_FAILURE:
- 'userauth-failure', MSG_USERAUTH_SUCCESS: 'userauth-success',
- MSG_USERAUTH_BANNER: 'userauth--banner', MSG_USERAUTH_PK_OK:
- 'userauth-60(pk-ok/info-request)', MSG_USERAUTH_INFO_RESPONSE:
- 'userauth-info-response', MSG_GLOBAL_REQUEST: 'global-request',
- MSG_REQUEST_SUCCESS: 'request-success', MSG_REQUEST_FAILURE:
- 'request-failure', MSG_CHANNEL_OPEN: 'channel-open',
- MSG_CHANNEL_OPEN_SUCCESS: 'channel-open-success',
- MSG_CHANNEL_OPEN_FAILURE: 'channel-open-failure',
- MSG_CHANNEL_WINDOW_ADJUST: 'channel-window-adjust', MSG_CHANNEL_DATA:
- 'channel-data', MSG_CHANNEL_EXTENDED_DATA: 'channel-extended-data',
- MSG_CHANNEL_EOF: 'channel-eof', MSG_CHANNEL_CLOSE: 'channel-close',
- MSG_CHANNEL_REQUEST: 'channel-request', MSG_CHANNEL_SUCCESS:
- 'channel-success', MSG_CHANNEL_FAILURE: 'channel-failure',
- MSG_USERAUTH_GSSAPI_RESPONSE: 'userauth-gssapi-response',
- MSG_USERAUTH_GSSAPI_TOKEN: 'userauth-gssapi-token',
- MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE:
- 'userauth-gssapi-exchange-complete', MSG_USERAUTH_GSSAPI_ERROR:
- 'userauth-gssapi-error', MSG_USERAUTH_GSSAPI_ERRTOK:
- 'userauth-gssapi-error-token', MSG_USERAUTH_GSSAPI_MIC:
- 'userauth-gssapi-mic'}
+
+# for debugging:
+MSG_NAMES = {
+ MSG_DISCONNECT: "disconnect",
+ MSG_IGNORE: "ignore",
+ MSG_UNIMPLEMENTED: "unimplemented",
+ MSG_DEBUG: "debug",
+ MSG_SERVICE_REQUEST: "service-request",
+ MSG_SERVICE_ACCEPT: "service-accept",
+ MSG_KEXINIT: "kexinit",
+ MSG_EXT_INFO: "ext-info",
+ MSG_NEWKEYS: "newkeys",
+ 30: "kex30",
+ 31: "kex31",
+ 32: "kex32",
+ 33: "kex33",
+ 34: "kex34",
+ 40: "kex40",
+ 41: "kex41",
+ MSG_USERAUTH_REQUEST: "userauth-request",
+ MSG_USERAUTH_FAILURE: "userauth-failure",
+ MSG_USERAUTH_SUCCESS: "userauth-success",
+ MSG_USERAUTH_BANNER: "userauth--banner",
+ MSG_USERAUTH_PK_OK: "userauth-60(pk-ok/info-request)",
+ MSG_USERAUTH_INFO_RESPONSE: "userauth-info-response",
+ MSG_GLOBAL_REQUEST: "global-request",
+ MSG_REQUEST_SUCCESS: "request-success",
+ MSG_REQUEST_FAILURE: "request-failure",
+ MSG_CHANNEL_OPEN: "channel-open",
+ MSG_CHANNEL_OPEN_SUCCESS: "channel-open-success",
+ MSG_CHANNEL_OPEN_FAILURE: "channel-open-failure",
+ MSG_CHANNEL_WINDOW_ADJUST: "channel-window-adjust",
+ MSG_CHANNEL_DATA: "channel-data",
+ MSG_CHANNEL_EXTENDED_DATA: "channel-extended-data",
+ MSG_CHANNEL_EOF: "channel-eof",
+ MSG_CHANNEL_CLOSE: "channel-close",
+ MSG_CHANNEL_REQUEST: "channel-request",
+ MSG_CHANNEL_SUCCESS: "channel-success",
+ MSG_CHANNEL_FAILURE: "channel-failure",
+ MSG_USERAUTH_GSSAPI_RESPONSE: "userauth-gssapi-response",
+ MSG_USERAUTH_GSSAPI_TOKEN: "userauth-gssapi-token",
+ MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE: "userauth-gssapi-exchange-complete",
+ MSG_USERAUTH_GSSAPI_ERROR: "userauth-gssapi-error",
+ MSG_USERAUTH_GSSAPI_ERRTOK: "userauth-gssapi-error-token",
+ MSG_USERAUTH_GSSAPI_MIC: "userauth-gssapi-mic",
+}
+
+
+# authentication request return codes:
AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED = range(3)
-(OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
- OPEN_FAILED_CONNECT_FAILED, OPEN_FAILED_UNKNOWN_CHANNEL_TYPE,
- OPEN_FAILED_RESOURCE_SHORTAGE) = range(0, 5)
-CONNECTION_FAILED_CODE = {(1): 'Administratively prohibited', (2):
- 'Connect failed', (3): 'Unknown channel type', (4): 'Resource shortage'}
-(DISCONNECT_SERVICE_NOT_AVAILABLE, DISCONNECT_AUTH_CANCELLED_BY_USER,
- DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE) = 7, 13, 14
+
+
+# channel request failed reasons:
+(
+ OPEN_SUCCEEDED,
+ OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
+ OPEN_FAILED_CONNECT_FAILED,
+ OPEN_FAILED_UNKNOWN_CHANNEL_TYPE,
+ OPEN_FAILED_RESOURCE_SHORTAGE,
+) = range(0, 5)
+
+
+CONNECTION_FAILED_CODE = {
+ 1: "Administratively prohibited",
+ 2: "Connect failed",
+ 3: "Unknown channel type",
+ 4: "Resource shortage",
+}
+
+
+(
+ DISCONNECT_SERVICE_NOT_AVAILABLE,
+ DISCONNECT_AUTH_CANCELLED_BY_USER,
+ DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
+) = (7, 13, 14)
+
zero_byte = byte_chr(0)
one_byte = byte_chr(1)
four_byte = byte_chr(4)
-max_byte = byte_chr(255)
+max_byte = byte_chr(0xFF)
cr_byte = byte_chr(13)
linefeed_byte = byte_chr(10)
crlf = cr_byte + linefeed_byte
cr_byte_value = 13
linefeed_byte_value = 10
-xffffffff = 4294967295
-x80000000 = 2147483648
+
+
+xffffffff = 0xFFFFFFFF
+x80000000 = 0x80000000
o666 = 438
o660 = 432
o644 = 420
@@ -109,14 +219,27 @@ o600 = 384
o777 = 511
o700 = 448
o70 = 56
+
DEBUG = logging.DEBUG
INFO = logging.INFO
WARNING = logging.WARNING
ERROR = logging.ERROR
CRITICAL = logging.CRITICAL
+
+# Common IO/select/etc sleep period, in seconds
io_sleep = 0.01
-DEFAULT_WINDOW_SIZE = 64 * 2 ** 15
-DEFAULT_MAX_PACKET_SIZE = 2 ** 15
-MIN_WINDOW_SIZE = 2 ** 15
-MIN_PACKET_SIZE = 2 ** 12
-MAX_WINDOW_SIZE = 2 ** 32 - 1
+
+DEFAULT_WINDOW_SIZE = 64 * 2**15
+DEFAULT_MAX_PACKET_SIZE = 2**15
+
+# lower bound on the max packet size we'll accept from the remote host
+# Minimum packet size is 32768 bytes according to
+# http://www.ietf.org/rfc/rfc4254.txt
+MIN_WINDOW_SIZE = 2**15
+
+# However, according to http://www.ietf.org/rfc/rfc4253.txt it is perfectly
+# legal to accept a size much smaller, as OpenSSH client does as size 16384.
+MIN_PACKET_SIZE = 2**12
+
+# Max windows size according to http://www.ietf.org/rfc/rfc4254.txt
+MAX_WINDOW_SIZE = 2**32 - 1
diff --git a/paramiko/compress.py b/paramiko/compress.py
index 64c87ad4..18ff4843 100644
--- a/paramiko/compress.py
+++ b/paramiko/compress.py
@@ -1,12 +1,31 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Compression implementations for a Transport.
"""
+
import zlib
class ZlibCompressor:
-
def __init__(self):
+ # Use the default level of zlib compression
self.z = zlib.compressobj()
def __call__(self, data):
@@ -14,7 +33,6 @@ class ZlibCompressor:
class ZlibDecompressor:
-
def __init__(self):
self.z = zlib.decompressobj()
diff --git a/paramiko/config.py b/paramiko/config.py
index 3301afef..8ab55c64 100644
--- a/paramiko/config.py
+++ b/paramiko/config.py
@@ -1,6 +1,26 @@
+# Copyright (C) 2006-2007 Robey Pointer <robeypointer@gmail.com>
+# Copyright (C) 2012 Olle Lundberg <geek@nerd.sh>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Configuration file (aka ``ssh_config``) support.
"""
+
import fnmatch
import getpass
import os
@@ -10,12 +30,16 @@ import socket
from hashlib import sha1
from io import StringIO
from functools import partial
+
invoke, invoke_import_error = None, None
try:
import invoke
except ImportError as e:
invoke_import_error = e
+
from .ssh_exception import CouldNotCanonicalize, ConfigParseError
+
+
SSH_PORT = 22
@@ -29,12 +53,21 @@ class SSHConfig:
.. versionadded:: 1.6
"""
- SETTINGS_REGEX = re.compile('(\\w+)(?:\\s*=\\s*|\\s+)(.+)')
- TOKENS_BY_CONFIG_KEY = {'controlpath': ['%C', '%h', '%l', '%L', '%n',
- '%p', '%r', '%u'], 'hostname': ['%h'], 'identityfile': ['%C', '~',
- '%d', '%h', '%l', '%u', '%r'], 'proxycommand': ['~', '%h', '%p',
- '%r'], 'proxyjump': ['%h', '%p', '%r'], 'match-exec': ['%C', '%d',
- '%h', '%L', '%l', '%n', '%p', '%r', '%u']}
+
+ SETTINGS_REGEX = re.compile(r"(\w+)(?:\s*=\s*|\s+)(.+)")
+
+ # TODO: do a full scan of ssh.c & friends to make sure we're fully
+ # compatible across the board, e.g. OpenSSH 8.1 added %n to ProxyCommand.
+ TOKENS_BY_CONFIG_KEY = {
+ "controlpath": ["%C", "%h", "%l", "%L", "%n", "%p", "%r", "%u"],
+ "hostname": ["%h"],
+ "identityfile": ["%C", "~", "%d", "%h", "%l", "%u", "%r"],
+ "proxycommand": ["~", "%h", "%p", "%r"],
+ "proxyjump": ["%h", "%p", "%r"],
+ # Doesn't seem worth making this 'special' for now, it will fit well
+ # enough (no actual match-exec config key to be confused with).
+ "match-exec": ["%C", "%d", "%h", "%L", "%l", "%n", "%p", "%r", "%u"],
+ }
def __init__(self):
"""
@@ -64,7 +97,7 @@ class SSHConfig:
.. versionadded:: 2.7
"""
- pass
+ return cls.from_file(StringIO(text))
@classmethod
def from_path(cls, path):
@@ -73,7 +106,8 @@ class SSHConfig:
.. versionadded:: 2.7
"""
- pass
+ with open(path) as flo:
+ return cls.from_file(flo)
@classmethod
def from_file(cls, flo):
@@ -82,7 +116,9 @@ class SSHConfig:
.. versionadded:: 2.7
"""
- pass
+ obj = cls()
+ obj.parse(flo)
+ return obj
def parse(self, file_obj):
"""
@@ -90,7 +126,59 @@ class SSHConfig:
:param file_obj: a file-like object to read the config file from
"""
- pass
+ # Start out w/ implicit/anonymous global host-like block to hold
+ # anything not contained by an explicit one.
+ context = {"host": ["*"], "config": {}}
+ for line in file_obj:
+ # Strip any leading or trailing whitespace from the line.
+ # Refer to https://github.com/paramiko/paramiko/issues/499
+ line = line.strip()
+ # Skip blanks, comments
+ if not line or line.startswith("#"):
+ continue
+
+ # Parse line into key, value
+ match = re.match(self.SETTINGS_REGEX, line)
+ if not match:
+ raise ConfigParseError("Unparsable line {}".format(line))
+ key = match.group(1).lower()
+ value = match.group(2)
+
+ # Host keyword triggers switch to new block/context
+ if key in ("host", "match"):
+ self._config.append(context)
+ context = {"config": {}}
+ if key == "host":
+ # TODO 4.0: make these real objects or at least name this
+ # "hosts" to acknowledge it's an iterable. (Doing so prior
+ # to 3.0, despite it being a private API, feels bad -
+ # surely such an old codebase has folks actually relying on
+ # these keys.)
+ context["host"] = self._get_hosts(value)
+ else:
+ context["matches"] = self._get_matches(value)
+ # Special-case for noop ProxyCommands
+ elif key == "proxycommand" and value.lower() == "none":
+ # Store 'none' as None - not as a string implying that the
+ # proxycommand is the literal shell command "none"!
+ context["config"][key] = None
+ # All other keywords get stored, directly or via append
+ else:
+ if value.startswith('"') and value.endswith('"'):
+ value = value[1:-1]
+
+ # identityfile, localforward, remoteforward keys are special
+ # cases, since they are allowed to be specified multiple times
+ # and they should be tried in order of specification.
+ if key in ["identityfile", "localforward", "remoteforward"]:
+ if key in context["config"]:
+ context["config"][key].append(value)
+ else:
+ context["config"][key] = [value]
+ elif key not in context["config"]:
+ context["config"][key] = value
+ # Store last 'open' block and we're done
+ self._config.append(context)
def lookup(self, hostname):
"""
@@ -133,7 +221,65 @@ class SSHConfig:
.. versionchanged:: 3.3
Added ``Match final`` support.
"""
- pass
+ # First pass
+ options = self._lookup(hostname=hostname)
+ # Inject HostName if it was not set (this used to be done incidentally
+ # during tokenization, for some reason).
+ if "hostname" not in options:
+ options["hostname"] = hostname
+ # Handle canonicalization
+ canon = options.get("canonicalizehostname", None) in ("yes", "always")
+ maxdots = int(options.get("canonicalizemaxdots", 1))
+ if canon and hostname.count(".") <= maxdots:
+ # NOTE: OpenSSH manpage does not explicitly state this, but its
+ # implementation for CanonicalDomains is 'split on any whitespace'.
+ domains = options["canonicaldomains"].split()
+ hostname = self.canonicalize(hostname, options, domains)
+ # Overwrite HostName again here (this is also what OpenSSH does)
+ options["hostname"] = hostname
+ options = self._lookup(
+ hostname, options, canonical=True, final=True
+ )
+ else:
+ options = self._lookup(
+ hostname, options, canonical=False, final=True
+ )
+ return options
+
+ def _lookup(self, hostname, options=None, canonical=False, final=False):
+ # Init
+ if options is None:
+ options = SSHConfigDict()
+ # Iterate all stanzas, applying any that match, in turn (so that things
+ # like Match can reference currently understood state)
+ for context in self._config:
+ if not (
+ self._pattern_matches(context.get("host", []), hostname)
+ or self._does_match(
+ context.get("matches", []),
+ hostname,
+ canonical,
+ final,
+ options,
+ )
+ ):
+ continue
+ for key, value in context["config"].items():
+ if key not in options:
+ # Create a copy of the original value,
+ # else it will reference the original list
+ # in self._config and update that value too
+ # when the extend() is being called.
+ options[key] = value[:] if value is not None else value
+ elif key == "identityfile":
+ options[key].extend(
+ x for x in value if x not in options[key]
+ )
+ if final:
+ # Expand variables in resulting values
+ # (besides 'Match exec' which was already handled above)
+ options = self._expand_variables(options, hostname)
+ return options
def canonicalize(self, hostname, options, domains):
"""
@@ -147,14 +293,120 @@ class SSHConfig:
.. versionadded:: 2.7
"""
- pass
+ found = False
+ for domain in domains:
+ candidate = "{}.{}".format(hostname, domain)
+ family_specific = _addressfamily_host_lookup(candidate, options)
+ if family_specific is not None:
+ # TODO: would we want to dig deeper into other results? e.g. to
+ # find something that satisfies PermittedCNAMEs when that is
+ # implemented?
+ found = family_specific[0]
+ else:
+ # TODO: what does ssh use here and is there a reason to use
+ # that instead of gethostbyname?
+ try:
+ found = socket.gethostbyname(candidate)
+ except socket.gaierror:
+ pass
+ if found:
+ # TODO: follow CNAME (implied by found != candidate?) if
+ # CanonicalizePermittedCNAMEs allows it
+ return candidate
+ # If we got here, it means canonicalization failed.
+ # When CanonicalizeFallbackLocal is undefined or 'yes', we just spit
+ # back the original hostname.
+ if options.get("canonicalizefallbacklocal", "yes") == "yes":
+ return hostname
+ # And here, we failed AND fallback was set to a non-yes value, so we
+ # need to get mad.
+ raise CouldNotCanonicalize(hostname)
def get_hostnames(self):
"""
Return the set of literal hostnames defined in the SSH config (both
explicit hostnames and wildcard entries).
"""
- pass
+ hosts = set()
+ for entry in self._config:
+ hosts.update(entry["host"])
+ return hosts
+
+ def _pattern_matches(self, patterns, target):
+ # Convenience auto-splitter if not already a list
+ if hasattr(patterns, "split"):
+ patterns = patterns.split(",")
+ match = False
+ for pattern in patterns:
+ # Short-circuit if target matches a negated pattern
+ if pattern.startswith("!") and fnmatch.fnmatch(
+ target, pattern[1:]
+ ):
+ return False
+ # Flag a match, but continue (in case of later negation) if regular
+ # match occurs
+ elif fnmatch.fnmatch(target, pattern):
+ match = True
+ return match
+
+ def _does_match(
+ self, match_list, target_hostname, canonical, final, options
+ ):
+ matched = []
+ candidates = match_list[:]
+ local_username = getpass.getuser()
+ while candidates:
+ candidate = candidates.pop(0)
+ passed = None
+ # Obtain latest host/user value every loop, so later Match may
+ # reference values assigned within a prior Match.
+ configured_host = options.get("hostname", None)
+ configured_user = options.get("user", None)
+ type_, param = candidate["type"], candidate["param"]
+ # Canonical is a hard pass/fail based on whether this is a
+ # canonicalized re-lookup.
+ if type_ == "canonical":
+ if self._should_fail(canonical, candidate):
+ return False
+ if type_ == "final":
+ passed = final
+ # The parse step ensures we only see this by itself or after
+ # canonical, so it's also an easy hard pass. (No negation here as
+ # that would be uh, pretty weird?)
+ elif type_ == "all":
+ return True
+ # From here, we are testing various non-hard criteria,
+ # short-circuiting only on fail
+ elif type_ == "host":
+ hostval = configured_host or target_hostname
+ passed = self._pattern_matches(param, hostval)
+ elif type_ == "originalhost":
+ passed = self._pattern_matches(param, target_hostname)
+ elif type_ == "user":
+ user = configured_user or local_username
+ passed = self._pattern_matches(param, user)
+ elif type_ == "localuser":
+ passed = self._pattern_matches(param, local_username)
+ elif type_ == "exec":
+ exec_cmd = self._tokenize(
+ options, target_hostname, "match-exec", param
+ )
+ # This is the laziest spot in which we can get mad about an
+ # inability to import Invoke.
+ if invoke is None:
+ raise invoke_import_error
+ # Like OpenSSH, we 'redirect' stdout but let stderr bubble up
+ passed = invoke.run(exec_cmd, hide="stdout", warn=True).ok
+ # Tackle any 'passed, but was negated' results from above
+ if passed is not None and self._should_fail(passed, candidate):
+ return False
+ # Made it all the way here? Everything matched!
+ matched.append(candidate)
+ # Did anything match? (To be treated as bool, usually.)
+ return matched
+
+ def _should_fail(self, would_pass, candidate):
+ return would_pass if candidate["negate"] else not would_pass
def _tokenize(self, config, target_hostname, key, value):
"""
@@ -167,7 +419,56 @@ class SSHConfig:
:returns: The tokenized version of the input ``value`` string.
"""
- pass
+ allowed_tokens = self._allowed_tokens(key)
+ # Short-circuit if no tokenization possible
+ if not allowed_tokens:
+ return value
+ # Obtain potentially configured hostname, for use with %h.
+ # Special-case where we are tokenizing the hostname itself, to avoid
+ # replacing %h with a %h-bearing value, etc.
+ configured_hostname = target_hostname
+ if key != "hostname":
+ configured_hostname = config.get("hostname", configured_hostname)
+ # Ditto the rest of the source values
+ if "port" in config:
+ port = config["port"]
+ else:
+ port = SSH_PORT
+ user = getpass.getuser()
+ if "user" in config:
+ remoteuser = config["user"]
+ else:
+ remoteuser = user
+ local_hostname = socket.gethostname().split(".")[0]
+ local_fqdn = LazyFqdn(config, local_hostname)
+ homedir = os.path.expanduser("~")
+ tohash = local_hostname + target_hostname + repr(port) + remoteuser
+ # The actual tokens!
+ replacements = {
+ # TODO: %%???
+ "%C": sha1(tohash.encode()).hexdigest(),
+ "%d": homedir,
+ "%h": configured_hostname,
+ # TODO: %i?
+ "%L": local_hostname,
+ "%l": local_fqdn,
+ # also this is pseudo buggy when not in Match exec mode so document
+ # that. also WHY is that the case?? don't we do all of this late?
+ "%n": target_hostname,
+ "%p": port,
+ "%r": remoteuser,
+ # TODO: %T? don't believe this is possible however
+ "%u": user,
+ "~": homedir,
+ }
+ # Do the thing with the stuff
+ tokenized = value
+ for find, replace in replacements.items():
+ if find not in allowed_tokens:
+ continue
+ tokenized = tokenized.replace(find, str(replace))
+ # TODO: log? eg that value -> tokenized
+ return tokenized
def _allowed_tokens(self, key):
"""
@@ -178,7 +479,7 @@ class SSHConfig:
preserve as-strict-as-possible compatibility with OpenSSH, which
for whatever reason only applies some tokens to some config keys.
"""
- pass
+ return self.TOKENS_BY_CONFIG_KEY.get(key, [])
def _expand_variables(self, config, target_hostname):
"""
@@ -190,13 +491,25 @@ class SSHConfig:
:param dict config: the currently parsed config
:param str hostname: the hostname whose config is being looked up
"""
- pass
+ for k in config:
+ if config[k] is None:
+ continue
+ tokenizer = partial(self._tokenize, config, target_hostname, k)
+ if isinstance(config[k], list):
+ for i, value in enumerate(config[k]):
+ config[k][i] = tokenizer(value)
+ else:
+ config[k] = tokenizer(config[k])
+ return config
def _get_hosts(self, host):
"""
Return a list of host_names from host value.
"""
- pass
+ try:
+ return shlex.split(host)
+ except ValueError:
+ raise ConfigParseError("Unparsable host {}".format(host))
def _get_matches(self, match):
"""
@@ -204,7 +517,43 @@ class SSHConfig:
Performs some parse-time validation as well.
"""
- pass
+ matches = []
+ tokens = shlex.split(match)
+ while tokens:
+ match = {"type": None, "param": None, "negate": False}
+ type_ = tokens.pop(0)
+ # Handle per-keyword negation
+ if type_.startswith("!"):
+ match["negate"] = True
+ type_ = type_[1:]
+ match["type"] = type_
+ # all/canonical have no params (everything else does)
+ if type_ in ("all", "canonical", "final"):
+ matches.append(match)
+ continue
+ if not tokens:
+ raise ConfigParseError(
+ "Missing parameter to Match '{}' keyword".format(type_)
+ )
+ match["param"] = tokens.pop(0)
+ matches.append(match)
+ # Perform some (easier to do now than in the middle) validation that is
+ # better handled here than at lookup time.
+ keywords = [x["type"] for x in matches]
+ if "all" in keywords:
+ allowable = ("all", "canonical")
+ ok, bad = (
+ list(filter(lambda x: x in allowable, keywords)),
+ list(filter(lambda x: x not in allowable, keywords)),
+ )
+ err = None
+ if any(bad):
+ err = "Match does not allow 'all' mixed with anything but 'canonical'" # noqa
+ elif "canonical" in ok and ok.index("canonical") > ok.index("all"):
+ err = "Match does not allow 'all' before 'canonical'"
+ if err is not None:
+ raise ConfigParseError(err)
+ return matches
def _addressfamily_host_lookup(hostname, options):
@@ -225,7 +574,23 @@ def _addressfamily_host_lookup(hostname, options):
:param options: `SSHConfigDict` instance w/ parsed options.
:returns: ``getaddrinfo``-style tuples, or ``None``, depending.
"""
- pass
+ address_family = options.get("addressfamily", "any").lower()
+ if address_family == "any":
+ return
+ try:
+ family = socket.AF_INET6
+ if address_family == "inet":
+ family = socket.AF_INET
+ return socket.getaddrinfo(
+ hostname,
+ None,
+ family,
+ socket.SOCK_DGRAM,
+ socket.IPPROTO_IP,
+ socket.AI_CANONNAME,
+ )
+ except socket.gaierror:
+ pass
class LazyFqdn:
@@ -240,16 +605,28 @@ class LazyFqdn:
def __str__(self):
if self.fqdn is None:
+ #
+ # If the SSH config contains AddressFamily, use that when
+ # determining the local host's FQDN. Using socket.getfqdn() from
+ # the standard library is the most general solution, but can
+ # result in noticeable delays on some platforms when IPv6 is
+ # misconfigured or not available, as it calls getaddrinfo with no
+ # address family specified, so both IPv4 and IPv6 are checked.
+ #
+
+ # Handle specific option
fqdn = None
results = _addressfamily_host_lookup(self.host, self.config)
if results is not None:
for res in results:
af, socktype, proto, canonname, sa = res
- if canonname and '.' in canonname:
+ if canonname and "." in canonname:
fqdn = canonname
break
+ # Handle 'any' / unspecified / lookup failure
if fqdn is None:
fqdn = socket.getfqdn()
+ # Cache
self.fqdn = fqdn
return self.fqdn
@@ -302,7 +679,10 @@ class SSHConfigDict(dict):
.. versionadded:: 2.5
"""
- pass
+ val = self[key]
+ if isinstance(val, bool):
+ return val
+ return val.lower() == "yes"
def as_int(self, key):
"""
@@ -313,4 +693,4 @@ class SSHConfigDict(dict):
.. versionadded:: 2.5
"""
- pass
+ return int(self[key])
diff --git a/paramiko/dsskey.py b/paramiko/dsskey.py
index a2882c5e..5215d282 100644
--- a/paramiko/dsskey.py
+++ b/paramiko/dsskey.py
@@ -1,11 +1,34 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
DSS keys.
"""
+
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import dsa
-from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature, encode_dss_signature
+from cryptography.hazmat.primitives.asymmetric.utils import (
+ decode_dss_signature,
+ encode_dss_signature,
+)
+
from paramiko import util
from paramiko.common import zero_byte
from paramiko.ssh_exception import SSHException
@@ -19,10 +42,18 @@ class DSSKey(PKey):
Representation of a DSS key which can be used to sign an verify SSH2
data.
"""
- name = 'ssh-dss'
- def __init__(self, msg=None, data=None, filename=None, password=None,
- vals=None, file_obj=None):
+ name = "ssh-dss"
+
+ def __init__(
+ self,
+ msg=None,
+ data=None,
+ filename=None,
+ password=None,
+ vals=None,
+ file_obj=None,
+ ):
self.p = None
self.q = None
self.g = None
@@ -35,22 +66,138 @@ class DSSKey(PKey):
if filename is not None:
self._from_private_key_file(filename, password)
return
- if msg is None and data is not None:
+ if (msg is None) and (data is not None):
msg = Message(data)
if vals is not None:
self.p, self.q, self.g, self.y = vals
else:
- self._check_type_and_load_cert(msg=msg, key_type=self.name,
- cert_type=f'{self.name}-cert-v01@openssh.com')
+ self._check_type_and_load_cert(
+ msg=msg,
+ key_type=self.name,
+ cert_type=f"{self.name}-cert-v01@openssh.com",
+ )
self.p = msg.get_mpint()
self.q = msg.get_mpint()
self.g = msg.get_mpint()
self.y = msg.get_mpint()
self.size = util.bit_length(self.p)
+ def asbytes(self):
+ m = Message()
+ m.add_string(self.name)
+ m.add_mpint(self.p)
+ m.add_mpint(self.q)
+ m.add_mpint(self.g)
+ m.add_mpint(self.y)
+ return m.asbytes()
+
def __str__(self):
return self.asbytes()
+ @property
+ def _fields(self):
+ return (self.get_name(), self.p, self.q, self.g, self.y)
+
+ # TODO 4.0: remove
+ def get_name(self):
+ return self.name
+
+ def get_bits(self):
+ return self.size
+
+ def can_sign(self):
+ return self.x is not None
+
+ def sign_ssh_data(self, data, algorithm=None):
+ key = dsa.DSAPrivateNumbers(
+ x=self.x,
+ public_numbers=dsa.DSAPublicNumbers(
+ y=self.y,
+ parameter_numbers=dsa.DSAParameterNumbers(
+ p=self.p, q=self.q, g=self.g
+ ),
+ ),
+ ).private_key(backend=default_backend())
+ sig = key.sign(data, hashes.SHA1())
+ r, s = decode_dss_signature(sig)
+
+ m = Message()
+ m.add_string(self.name)
+ # apparently, in rare cases, r or s may be shorter than 20 bytes!
+ rstr = util.deflate_long(r, 0)
+ sstr = util.deflate_long(s, 0)
+ if len(rstr) < 20:
+ rstr = zero_byte * (20 - len(rstr)) + rstr
+ if len(sstr) < 20:
+ sstr = zero_byte * (20 - len(sstr)) + sstr
+ m.add_string(rstr + sstr)
+ return m
+
+ def verify_ssh_sig(self, data, msg):
+ if len(msg.asbytes()) == 40:
+ # spies.com bug: signature has no header
+ sig = msg.asbytes()
+ else:
+ kind = msg.get_text()
+ if kind != self.name:
+ return 0
+ sig = msg.get_binary()
+
+ # pull out (r, s) which are NOT encoded as mpints
+ sigR = util.inflate_long(sig[:20], 1)
+ sigS = util.inflate_long(sig[20:], 1)
+
+ signature = encode_dss_signature(sigR, sigS)
+
+ key = dsa.DSAPublicNumbers(
+ y=self.y,
+ parameter_numbers=dsa.DSAParameterNumbers(
+ p=self.p, q=self.q, g=self.g
+ ),
+ ).public_key(backend=default_backend())
+ try:
+ key.verify(signature, data, hashes.SHA1())
+ except InvalidSignature:
+ return False
+ else:
+ return True
+
+ def write_private_key_file(self, filename, password=None):
+ key = dsa.DSAPrivateNumbers(
+ x=self.x,
+ public_numbers=dsa.DSAPublicNumbers(
+ y=self.y,
+ parameter_numbers=dsa.DSAParameterNumbers(
+ p=self.p, q=self.q, g=self.g
+ ),
+ ),
+ ).private_key(backend=default_backend())
+
+ self._write_private_key_file(
+ filename,
+ key,
+ serialization.PrivateFormat.TraditionalOpenSSL,
+ password=password,
+ )
+
+ def write_private_key(self, file_obj, password=None):
+ key = dsa.DSAPrivateNumbers(
+ x=self.x,
+ public_numbers=dsa.DSAPublicNumbers(
+ y=self.y,
+ parameter_numbers=dsa.DSAParameterNumbers(
+ p=self.p, q=self.q, g=self.g
+ ),
+ ),
+ ).private_key(backend=default_backend())
+
+ self._write_private_key(
+ file_obj,
+ key,
+ serialization.PrivateFormat.TraditionalOpenSSL,
+ password=password,
+ )
+
@staticmethod
def generate(bits=1024, progress_func=None):
"""
@@ -61,4 +208,51 @@ class DSSKey(PKey):
:param progress_func: Unused
:return: new `.DSSKey` private key
"""
- pass
+ numbers = dsa.generate_private_key(
+ bits, backend=default_backend()
+ ).private_numbers()
+ key = DSSKey(
+ vals=(
+ numbers.public_numbers.parameter_numbers.p,
+ numbers.public_numbers.parameter_numbers.q,
+ numbers.public_numbers.parameter_numbers.g,
+ numbers.public_numbers.y,
+ )
+ )
+ key.x = numbers.x
+ return key
+
+ # ...internals...
+
+ def _from_private_key_file(self, filename, password):
+ data = self._read_private_key_file("DSA", filename, password)
+ self._decode_key(data)
+
+ def _from_private_key(self, file_obj, password):
+ data = self._read_private_key("DSA", file_obj, password)
+ self._decode_key(data)
+
+ def _decode_key(self, data):
+ pkformat, data = data
+ # private key file contains:
+ # DSAPrivateKey = { version = 0, p, q, g, y, x }
+ if pkformat == self._PRIVATE_KEY_FORMAT_ORIGINAL:
+ try:
+ keylist = BER(data).decode()
+ except BERException as e:
+ raise SSHException("Unable to parse key file: {}".format(e))
+ elif pkformat == self._PRIVATE_KEY_FORMAT_OPENSSH:
+ keylist = self._uint32_cstruct_unpack(data, "iiiii")
+ keylist = [0] + list(keylist)
+ else:
+ self._got_bad_key_format_id(pkformat)
+ if type(keylist) is not list or len(keylist) < 6 or keylist[0] != 0:
+ raise SSHException(
+ "not a valid DSA private key file (bad ber encoding)"
+ )
+ self.p = keylist[1]
+ self.q = keylist[2]
+ self.g = keylist[3]
+ self.y = keylist[4]
+ self.x = keylist[5]
+ self.size = util.bit_length(self.p)
diff --git a/paramiko/ecdsakey.py b/paramiko/ecdsakey.py
index 3c6f2ecf..6fd95fab 100644
--- a/paramiko/ecdsakey.py
+++ b/paramiko/ecdsakey.py
@@ -1,11 +1,34 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
ECDSA keys
"""
+
from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
-from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature, encode_dss_signature
+from cryptography.hazmat.primitives.asymmetric.utils import (
+ decode_dss_signature,
+ encode_dss_signature,
+)
+
from paramiko.common import four_byte
from paramiko.message import Message
from paramiko.pkey import PKey
@@ -25,13 +48,18 @@ class _ECDSACurve:
def __init__(self, curve_class, nist_name):
self.nist_name = nist_name
self.key_length = curve_class.key_size
- self.key_format_identifier = 'ecdsa-sha2-' + self.nist_name
+
+ # Defined in RFC 5656 6.2
+ self.key_format_identifier = "ecdsa-sha2-" + self.nist_name
+
+ # Defined in RFC 5656 6.2.1
if self.key_length <= 256:
self.hash_object = hashes.SHA256
elif self.key_length <= 384:
self.hash_object = hashes.SHA384
else:
self.hash_object = hashes.SHA512
+
self.curve_class = curve_class
@@ -45,18 +73,50 @@ class _ECDSACurveSet:
def __init__(self, ecdsa_curves):
self.ecdsa_curves = ecdsa_curves
+ def get_key_format_identifier_list(self):
+ return [curve.key_format_identifier for curve in self.ecdsa_curves]
+
+ def get_by_curve_class(self, curve_class):
+ for curve in self.ecdsa_curves:
+ if curve.curve_class == curve_class:
+ return curve
+
+ def get_by_key_format_identifier(self, key_format_identifier):
+ for curve in self.ecdsa_curves:
+ if curve.key_format_identifier == key_format_identifier:
+ return curve
+
+ def get_by_key_length(self, key_length):
+ for curve in self.ecdsa_curves:
+ if curve.key_length == key_length:
+ return curve
+
class ECDSAKey(PKey):
"""
Representation of an ECDSA key which can be used to sign and verify SSH2
data.
"""
- _ECDSA_CURVES = _ECDSACurveSet([_ECDSACurve(ec.SECP256R1, 'nistp256'),
- _ECDSACurve(ec.SECP384R1, 'nistp384'), _ECDSACurve(ec.SECP521R1,
- 'nistp521')])
- def __init__(self, msg=None, data=None, filename=None, password=None,
- vals=None, file_obj=None, validate_point=True):
+ _ECDSA_CURVES = _ECDSACurveSet(
+ [
+ _ECDSACurve(ec.SECP256R1, "nistp256"),
+ _ECDSACurve(ec.SECP384R1, "nistp384"),
+ _ECDSACurve(ec.SECP521R1, "nistp521"),
+ ]
+ )
+
+ def __init__(
+ self,
+ msg=None,
+ data=None,
+ filename=None,
+ password=None,
+ vals=None,
+ file_obj=None,
+ # TODO 4.0: remove; it does nothing since porting to cryptography.io
+ validate_point=True,
+ ):
self.verifying_key = None
self.signing_key = None
self.public_blob = None
@@ -66,39 +126,139 @@ class ECDSAKey(PKey):
if filename is not None:
self._from_private_key_file(filename, password)
return
- if msg is None and data is not None:
+ if (msg is None) and (data is not None):
msg = Message(data)
if vals is not None:
self.signing_key, self.verifying_key = vals
c_class = self.signing_key.curve.__class__
self.ecdsa_curve = self._ECDSA_CURVES.get_by_curve_class(c_class)
else:
+ # Must set ecdsa_curve first; subroutines called herein may need to
+ # spit out our get_name(), which relies on this.
key_type = msg.get_text()
- suffix = '-cert-v01@openssh.com'
+ # But this also means we need to hand it a real key/curve
+ # identifier, so strip out any cert business. (NOTE: could push
+ # that into _ECDSACurveSet.get_by_key_format_identifier(), but it
+ # feels more correct to do it here?)
+ suffix = "-cert-v01@openssh.com"
if key_type.endswith(suffix):
- key_type = key_type[:-len(suffix)]
+ key_type = key_type[: -len(suffix)]
self.ecdsa_curve = self._ECDSA_CURVES.get_by_key_format_identifier(
- key_type)
+ key_type
+ )
key_types = self._ECDSA_CURVES.get_key_format_identifier_list()
- cert_types = ['{}-cert-v01@openssh.com'.format(x) for x in
- key_types]
- self._check_type_and_load_cert(msg=msg, key_type=key_types,
- cert_type=cert_types)
+ cert_types = [
+ "{}-cert-v01@openssh.com".format(x) for x in key_types
+ ]
+ self._check_type_and_load_cert(
+ msg=msg, key_type=key_types, cert_type=cert_types
+ )
curvename = msg.get_text()
if curvename != self.ecdsa_curve.nist_name:
- raise SSHException("Can't handle curve of type {}".format(
- curvename))
+ raise SSHException(
+ "Can't handle curve of type {}".format(curvename)
+ )
+
pointinfo = msg.get_binary()
try:
- key = ec.EllipticCurvePublicKey.from_encoded_point(self.
- ecdsa_curve.curve_class(), pointinfo)
+ key = ec.EllipticCurvePublicKey.from_encoded_point(
+ self.ecdsa_curve.curve_class(), pointinfo
+ )
self.verifying_key = key
except ValueError:
- raise SSHException('Invalid public key')
+ raise SSHException("Invalid public key")
+
+ @classmethod
+ def identifiers(cls):
+ return cls._ECDSA_CURVES.get_key_format_identifier_list()
+
+ # TODO 4.0: deprecate/remove
+ @classmethod
+ def supported_key_format_identifiers(cls):
+ return cls.identifiers()
+
+ def asbytes(self):
+ key = self.verifying_key
+ m = Message()
+ m.add_string(self.ecdsa_curve.key_format_identifier)
+ m.add_string(self.ecdsa_curve.nist_name)
+
+ numbers = key.public_numbers()
+
+ key_size_bytes = (key.curve.key_size + 7) // 8
+
+ x_bytes = deflate_long(numbers.x, add_sign_padding=False)
+ x_bytes = b"\x00" * (key_size_bytes - len(x_bytes)) + x_bytes
+
+ y_bytes = deflate_long(numbers.y, add_sign_padding=False)
+ y_bytes = b"\x00" * (key_size_bytes - len(y_bytes)) + y_bytes
+
+ point_str = four_byte + x_bytes + y_bytes
+ m.add_string(point_str)
+ return m.asbytes()
def __str__(self):
return self.asbytes()
+ @property
+ def _fields(self):
+ return (
+ self.get_name(),
+ self.verifying_key.public_numbers().x,
+ self.verifying_key.public_numbers().y,
+ )
+
+ def get_name(self):
+ return self.ecdsa_curve.key_format_identifier
+
+ def get_bits(self):
+ return self.ecdsa_curve.key_length
+
+ def can_sign(self):
+ return self.signing_key is not None
+
+ def sign_ssh_data(self, data, algorithm=None):
+ ecdsa = ec.ECDSA(self.ecdsa_curve.hash_object())
+ sig = self.signing_key.sign(data, ecdsa)
+ r, s = decode_dss_signature(sig)
+
+ m = Message()
+ m.add_string(self.ecdsa_curve.key_format_identifier)
+ m.add_string(self._sigencode(r, s))
+ return m
+
+ def verify_ssh_sig(self, data, msg):
+ if msg.get_text() != self.ecdsa_curve.key_format_identifier:
+ return False
+ sig = msg.get_binary()
+ sigR, sigS = self._sigdecode(sig)
+ signature = encode_dss_signature(sigR, sigS)
+
+ try:
+ self.verifying_key.verify(
+ signature, data, ec.ECDSA(self.ecdsa_curve.hash_object())
+ )
+ except InvalidSignature:
+ return False
+ else:
+ return True
+
+ def write_private_key_file(self, filename, password=None):
+ self._write_private_key_file(
+ filename,
+ self.signing_key,
+ serialization.PrivateFormat.TraditionalOpenSSL,
+ password=password,
+ )
+
+ def write_private_key(self, file_obj, password=None):
+ self._write_private_key(
+ file_obj,
+ self.signing_key,
+ serialization.PrivateFormat.TraditionalOpenSSL,
+ password=password,
+ )
+
@classmethod
def generate(cls, curve=ec.SECP256R1(), progress_func=None, bits=None):
"""
@@ -108,4 +268,72 @@ class ECDSAKey(PKey):
:param progress_func: Not used for this type of key.
:returns: A new private key (`.ECDSAKey`) object
"""
- pass
+ if bits is not None:
+ curve = cls._ECDSA_CURVES.get_by_key_length(bits)
+ if curve is None:
+ raise ValueError("Unsupported key length: {:d}".format(bits))
+ curve = curve.curve_class()
+
+ private_key = ec.generate_private_key(curve, backend=default_backend())
+ return ECDSAKey(vals=(private_key, private_key.public_key()))
+
+ # ...internals...
+
+ def _from_private_key_file(self, filename, password):
+ data = self._read_private_key_file("EC", filename, password)
+ self._decode_key(data)
+
+ def _from_private_key(self, file_obj, password):
+ data = self._read_private_key("EC", file_obj, password)
+ self._decode_key(data)
+
+ def _decode_key(self, data):
+ pkformat, data = data
+ if pkformat == self._PRIVATE_KEY_FORMAT_ORIGINAL:
+ try:
+ key = serialization.load_der_private_key(
+ data, password=None, backend=default_backend()
+ )
+ except (
+ ValueError,
+ AssertionError,
+ TypeError,
+ UnsupportedAlgorithm,
+ ) as e:
+ raise SSHException(str(e))
+ elif pkformat == self._PRIVATE_KEY_FORMAT_OPENSSH:
+ try:
+ msg = Message(data)
+ curve_name = msg.get_text()
+ verkey = msg.get_binary() # noqa: F841
+ sigkey = msg.get_mpint()
+ name = "ecdsa-sha2-" + curve_name
+ curve = self._ECDSA_CURVES.get_by_key_format_identifier(name)
+ if not curve:
+ raise SSHException("Invalid key curve identifier")
+ key = ec.derive_private_key(
+ sigkey, curve.curve_class(), default_backend()
+ )
+ except Exception as e:
+ # PKey._read_private_key_openssh() should check or return
+ # keytype - parsing could fail for any reason due to wrong type
+ raise SSHException(str(e))
+ else:
+ self._got_bad_key_format_id(pkformat)
+
+ self.signing_key = key
+ self.verifying_key = key.public_key()
+ curve_class = key.curve.__class__
+ self.ecdsa_curve = self._ECDSA_CURVES.get_by_curve_class(curve_class)
+
+ def _sigencode(self, r, s):
+ msg = Message()
+ msg.add_mpint(r)
+ msg.add_mpint(s)
+ return msg.asbytes()
+
+ def _sigdecode(self, sig):
+ msg = Message(sig)
+ r = msg.get_mpint()
+ s = msg.get_mpint()
+ return r, s
diff --git a/paramiko/ed25519key.py b/paramiko/ed25519key.py
index 3b22f73a..e5e81ac5 100644
--- a/paramiko/ed25519key.py
+++ b/paramiko/ed25519key.py
@@ -1,7 +1,26 @@
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
import bcrypt
+
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher
+
import nacl.signing
+
from paramiko.message import Message
from paramiko.pkey import PKey, OPENSSH_AUTH_MAGIC, _unpad_openssh
from paramiko.util import b
@@ -19,26 +38,175 @@ class Ed25519Key(PKey):
.. versionchanged:: 2.3
Added a ``file_obj`` parameter to match other key classes.
"""
- name = 'ssh-ed25519'
- def __init__(self, msg=None, data=None, filename=None, password=None,
- file_obj=None):
+ name = "ssh-ed25519"
+
+ def __init__(
+ self, msg=None, data=None, filename=None, password=None, file_obj=None
+ ):
self.public_blob = None
verifying_key = signing_key = None
if msg is None and data is not None:
msg = Message(data)
if msg is not None:
- self._check_type_and_load_cert(msg=msg, key_type=self.name,
- cert_type='ssh-ed25519-cert-v01@openssh.com')
+ self._check_type_and_load_cert(
+ msg=msg,
+ key_type=self.name,
+ cert_type="ssh-ed25519-cert-v01@openssh.com",
+ )
verifying_key = nacl.signing.VerifyKey(msg.get_binary())
elif filename is not None:
- with open(filename, 'r') as f:
- pkformat, data = self._read_private_key('OPENSSH', f)
+ with open(filename, "r") as f:
+ pkformat, data = self._read_private_key("OPENSSH", f)
elif file_obj is not None:
- pkformat, data = self._read_private_key('OPENSSH', file_obj)
+ pkformat, data = self._read_private_key("OPENSSH", file_obj)
+
if filename or file_obj:
signing_key = self._parse_signing_key_data(data, password)
+
if signing_key is None and verifying_key is None:
- raise ValueError('need a key')
+ raise ValueError("need a key")
+
self._signing_key = signing_key
self._verifying_key = verifying_key
+
+ def _parse_signing_key_data(self, data, password):
+ from paramiko.transport import Transport
+
+ # We may eventually want this to be usable for other key types, as
+ # OpenSSH moves to it, but for now this is just for Ed25519 keys.
+ # This format is described here:
+ # https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key
+ # The description isn't totally complete, and I had to refer to the
+ # source for a full implementation.
+ message = Message(data)
+ if message.get_bytes(len(OPENSSH_AUTH_MAGIC)) != OPENSSH_AUTH_MAGIC:
+ raise SSHException("Invalid key")
+
+ ciphername = message.get_text()
+ kdfname = message.get_text()
+ kdfoptions = message.get_binary()
+ num_keys = message.get_int()
+
+ if kdfname == "none":
+ # kdfname of "none" must have an empty kdfoptions, the ciphername
+ # must be "none"
+ if kdfoptions or ciphername != "none":
+ raise SSHException("Invalid key")
+ elif kdfname == "bcrypt":
+ if not password:
+ raise PasswordRequiredException(
+ "Private key file is encrypted"
+ )
+ kdf = Message(kdfoptions)
+ bcrypt_salt = kdf.get_binary()
+ bcrypt_rounds = kdf.get_int()
+ else:
+ raise SSHException("Invalid key")
+
+ if ciphername != "none" and ciphername not in Transport._cipher_info:
+ raise SSHException("Invalid key")
+
+ public_keys = []
+ for _ in range(num_keys):
+ pubkey = Message(message.get_binary())
+ if pubkey.get_text() != self.name:
+ raise SSHException("Invalid key")
+ public_keys.append(pubkey.get_binary())
+
+ private_ciphertext = message.get_binary()
+ if ciphername == "none":
+ private_data = private_ciphertext
+ else:
+ cipher = Transport._cipher_info[ciphername]
+ key = bcrypt.kdf(
+ password=b(password),
+ salt=bcrypt_salt,
+ desired_key_bytes=cipher["key-size"] + cipher["block-size"],
+ rounds=bcrypt_rounds,
+ # We can't control how many rounds are on disk, so no sense
+ # warning about it.
+ ignore_few_rounds=True,
+ )
+ decryptor = Cipher(
+ cipher["class"](key[: cipher["key-size"]]),
+ cipher["mode"](key[cipher["key-size"] :]),
+ backend=default_backend(),
+ ).decryptor()
+ private_data = (
+ decryptor.update(private_ciphertext) + decryptor.finalize()
+ )
+
+ message = Message(_unpad_openssh(private_data))
+ if message.get_int() != message.get_int():
+ raise SSHException("Invalid key")
+
+ signing_keys = []
+ for i in range(num_keys):
+ if message.get_text() != self.name:
+ raise SSHException("Invalid key")
+ # A copy of the public key, again, ignore.
+ public = message.get_binary()
+ key_data = message.get_binary()
+ # The second half of the key data is yet another copy of the public
+ # key...
+ signing_key = nacl.signing.SigningKey(key_data[:32])
+ # Verify that all the public keys are the same...
+ assert (
+ signing_key.verify_key.encode()
+ == public
+ == public_keys[i]
+ == key_data[32:]
+ )
+ signing_keys.append(signing_key)
+ # Comment, ignore.
+ message.get_binary()
+
+ if len(signing_keys) != 1:
+ raise SSHException("Invalid key")
+ return signing_keys[0]
+
+ def asbytes(self):
+ if self.can_sign():
+ v = self._signing_key.verify_key
+ else:
+ v = self._verifying_key
+ m = Message()
+ m.add_string(self.name)
+ m.add_string(v.encode())
+ return m.asbytes()
+
+ @property
+ def _fields(self):
+ if self.can_sign():
+ v = self._signing_key.verify_key
+ else:
+ v = self._verifying_key
+ return (self.get_name(), v)
+
+ # TODO 4.0: remove
+ def get_name(self):
+ return self.name
+
+ def get_bits(self):
+ return 256
+
+ def can_sign(self):
+ return self._signing_key is not None
+
+ def sign_ssh_data(self, data, algorithm=None):
+ m = Message()
+ m.add_string(self.name)
+ m.add_string(self._signing_key.sign(data).signature)
+ return m
+
+ def verify_ssh_sig(self, data, msg):
+ if msg.get_text() != self.name:
+ return False
+
+ try:
+ self._verifying_key.verify(data, msg.get_binary())
+ except nacl.exceptions.BadSignatureError:
+ return False
+ else:
+ return True
diff --git a/paramiko/file.py b/paramiko/file.py
index 9ff61859..a36abb98 100644
--- a/paramiko/file.py
+++ b/paramiko/file.py
@@ -1,5 +1,30 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
from io import BytesIO
-from paramiko.common import linefeed_byte_value, crlf, cr_byte, linefeed_byte, cr_byte_value
+
+from paramiko.common import (
+ linefeed_byte_value,
+ crlf,
+ cr_byte,
+ linefeed_byte,
+ cr_byte_value,
+)
+
from paramiko.util import ClosingContextManager, u
@@ -8,17 +33,20 @@ class BufferedFile(ClosingContextManager):
Reusable base class to implement Python-style file buffering around a
simpler stream.
"""
+
_DEFAULT_BUFSIZE = 8192
+
SEEK_SET = 0
SEEK_CUR = 1
SEEK_END = 2
- FLAG_READ = 1
- FLAG_WRITE = 2
- FLAG_APPEND = 4
- FLAG_BINARY = 16
- FLAG_BUFFERED = 32
- FLAG_LINE_BUFFERED = 64
- FLAG_UNIVERSAL_NEWLINE = 128
+
+ FLAG_READ = 0x1
+ FLAG_WRITE = 0x2
+ FLAG_APPEND = 0x4
+ FLAG_BINARY = 0x10
+ FLAG_BUFFERED = 0x20
+ FLAG_LINE_BUFFERED = 0x40
+ FLAG_UNIVERSAL_NEWLINE = 0x80
def __init__(self):
self.newlines = None
@@ -28,7 +56,11 @@ class BufferedFile(ClosingContextManager):
self._rbuffer = bytes()
self._at_trailing_cr = False
self._closed = False
+ # pos - position within the file, according to the user
+ # realpos - position according the OS
+ # (these may be different because we buffer for line reading)
self._pos = self._realpos = 0
+ # size only matters for seekable files
self._size = 0
def __del__(self):
@@ -43,21 +75,24 @@ class BufferedFile(ClosingContextManager):
:raises: ``ValueError`` -- if the file is closed.
"""
if self._closed:
- raise ValueError('I/O operation on closed file')
+ raise ValueError("I/O operation on closed file")
return self
def close(self):
"""
Close the file. Future read and write operations will fail.
"""
- pass
+ self.flush()
+ self._closed = True
def flush(self):
"""
Write out any data in the write buffer. This may do nothing if write
buffering is not turned on.
"""
- pass
+ self._write_all(self._wbuffer.getvalue())
+ self._wbuffer = BytesIO()
+ return
def __next__(self):
"""
@@ -84,7 +119,7 @@ class BufferedFile(ClosingContextManager):
`True` if the file can be read from. If `False`, `read` will raise
an exception.
"""
- pass
+ return (self._flags & self.FLAG_READ) == self.FLAG_READ
def writable(self):
"""
@@ -94,7 +129,7 @@ class BufferedFile(ClosingContextManager):
`True` if the file can be written to. If `False`, `write` will
raise an exception.
"""
- pass
+ return (self._flags & self.FLAG_WRITE) == self.FLAG_WRITE
def seekable(self):
"""
@@ -104,7 +139,7 @@ class BufferedFile(ClosingContextManager):
`True` if the file supports random access. If `False`, `seek` will
raise an exception.
"""
- pass
+ return False
def readinto(self, buff):
"""
@@ -114,7 +149,9 @@ class BufferedFile(ClosingContextManager):
:returns:
The number of bytes read.
"""
- pass
+ data = self.read(len(buff))
+ buff[: len(data)] = data
+ return len(data)
def read(self, size=None):
"""
@@ -133,7 +170,47 @@ class BufferedFile(ClosingContextManager):
data read from the file (as bytes), or an empty string if EOF was
encountered immediately
"""
- pass
+ if self._closed:
+ raise IOError("File is closed")
+ if not (self._flags & self.FLAG_READ):
+ raise IOError("File is not open for reading")
+ if (size is None) or (size < 0):
+ # go for broke
+ result = bytearray(self._rbuffer)
+ self._rbuffer = bytes()
+ self._pos += len(result)
+ while True:
+ try:
+ new_data = self._read(self._DEFAULT_BUFSIZE)
+ except EOFError:
+ new_data = None
+ if (new_data is None) or (len(new_data) == 0):
+ break
+ result.extend(new_data)
+ self._realpos += len(new_data)
+ self._pos += len(new_data)
+ return bytes(result)
+ if size <= len(self._rbuffer):
+ result = self._rbuffer[:size]
+ self._rbuffer = self._rbuffer[size:]
+ self._pos += len(result)
+ return result
+ while len(self._rbuffer) < size:
+ read_size = size - len(self._rbuffer)
+ if self._flags & self.FLAG_BUFFERED:
+ read_size = max(self._bufsize, read_size)
+ try:
+ new_data = self._read(read_size)
+ except EOFError:
+ new_data = None
+ if (new_data is None) or (len(new_data) == 0):
+ break
+ self._rbuffer += new_data
+ self._realpos += len(new_data)
+ result = self._rbuffer[:size]
+ self._rbuffer = self._rbuffer[size:]
+ self._pos += len(result)
+ return result
def readline(self, size=None):
"""
@@ -157,7 +234,88 @@ class BufferedFile(ClosingContextManager):
Else: the encoding of the file is assumed to be UTF-8 and character
strings (`str`) are returned
"""
- pass
+ # it's almost silly how complex this function is.
+ if self._closed:
+ raise IOError("File is closed")
+ if not (self._flags & self.FLAG_READ):
+ raise IOError("File not open for reading")
+ line = self._rbuffer
+ truncated = False
+ while True:
+ if (
+ self._at_trailing_cr
+ and self._flags & self.FLAG_UNIVERSAL_NEWLINE
+ and len(line) > 0
+ ):
+ # edge case: the newline may be '\r\n' and we may have read
+ # only the first '\r' last time.
+ if line[0] == linefeed_byte_value:
+ line = line[1:]
+ self._record_newline(crlf)
+ else:
+ self._record_newline(cr_byte)
+ self._at_trailing_cr = False
+ # check size before looking for a linefeed, in case we already have
+ # enough.
+ if (size is not None) and (size >= 0):
+ if len(line) >= size:
+ # truncate line
+ self._rbuffer = line[size:]
+ line = line[:size]
+ truncated = True
+ break
+ n = size - len(line)
+ else:
+ n = self._bufsize
+ if linefeed_byte in line or (
+ self._flags & self.FLAG_UNIVERSAL_NEWLINE and cr_byte in line
+ ):
+ break
+ try:
+ new_data = self._read(n)
+ except EOFError:
+ new_data = None
+ if (new_data is None) or (len(new_data) == 0):
+ self._rbuffer = bytes()
+ self._pos += len(line)
+ return line if self._flags & self.FLAG_BINARY else u(line)
+ line += new_data
+ self._realpos += len(new_data)
+ # find the newline
+ pos = line.find(linefeed_byte)
+ if self._flags & self.FLAG_UNIVERSAL_NEWLINE:
+ rpos = line.find(cr_byte)
+ if (rpos >= 0) and (rpos < pos or pos < 0):
+ pos = rpos
+ if pos == -1:
+ # we couldn't find a newline in the truncated string, return it
+ self._pos += len(line)
+ return line if self._flags & self.FLAG_BINARY else u(line)
+ xpos = pos + 1
+ if (
+ line[pos] == cr_byte_value
+ and xpos < len(line)
+ and line[xpos] == linefeed_byte_value
+ ):
+ xpos += 1
+ # if the string was truncated, _rbuffer needs to have the string after
+ # the newline character plus the truncated part of the line we stored
+ # earlier in _rbuffer
+ if truncated:
+ self._rbuffer = line[xpos:] + self._rbuffer
+ else:
+ self._rbuffer = line[xpos:]
+
+ lf = line[pos:xpos]
+ line = line[:pos] + linefeed_byte
+ if (len(self._rbuffer) == 0) and (lf == cr_byte):
+ # we could read the line up to a '\r' and there could still be a
+ # '\n' following that we read next time. note that and eat it.
+ self._at_trailing_cr = True
+ else:
+ self._record_newline(lf)
+ self._pos += len(line)
+ return line if self._flags & self.FLAG_BINARY else u(line)
def readlines(self, sizehint=None):
"""
@@ -169,7 +327,17 @@ class BufferedFile(ClosingContextManager):
:param int sizehint: desired maximum number of bytes to read.
:returns: list of lines read from the file.
"""
- pass
+ lines = []
+ byte_count = 0
+ while True:
+ line = self.readline()
+ if len(line) == 0:
+ break
+ lines.append(line)
+ byte_count += len(line)
+ if (sizehint is not None) and (byte_count >= sizehint):
+ break
+ return lines
def seek(self, offset, whence=0):
"""
@@ -189,7 +357,7 @@ class BufferedFile(ClosingContextManager):
:raises: ``IOError`` -- if the file doesn't support random access.
"""
- pass
+ raise IOError("File does not support seeking.")
def tell(self):
"""
@@ -199,7 +367,7 @@ class BufferedFile(ClosingContextManager):
:returns: file position (`number <int>` of bytes).
"""
- pass
+ return self._pos
def write(self, data):
"""
@@ -210,7 +378,32 @@ class BufferedFile(ClosingContextManager):
:param data: ``str``/``bytes`` data to write
"""
- pass
+ if isinstance(data, str):
+ # Accept text and encode as utf-8 for compatibility only.
+ data = data.encode("utf-8")
+ if self._closed:
+ raise IOError("File is closed")
+ if not (self._flags & self.FLAG_WRITE):
+ raise IOError("File not open for writing")
+ if not (self._flags & self.FLAG_BUFFERED):
+ self._write_all(data)
+ return
+ self._wbuffer.write(data)
+ if self._flags & self.FLAG_LINE_BUFFERED:
+ # only scan the new data for linefeed, to avoid wasting time.
+ last_newline_pos = data.rfind(linefeed_byte)
+ if last_newline_pos >= 0:
+ wbuf = self._wbuffer.getvalue()
+ last_newline_pos += len(wbuf) - len(data)
+ self._write_all(wbuf[: last_newline_pos + 1])
+ self._wbuffer = BytesIO()
+ self._wbuffer.write(wbuf[last_newline_pos + 1 :])
+ return
+ # even if we're line buffering, if the buffer has grown past the
+ # buffer size, force a flush.
+ if self._wbuffer.tell() >= self._bufsize:
+ self.flush()
+ return
def writelines(self, sequence):
"""
@@ -221,14 +414,22 @@ class BufferedFile(ClosingContextManager):
:param sequence: an iterable sequence of strings.
"""
- pass
+ for line in sequence:
+ self.write(line)
+ return
def xreadlines(self):
"""
Identical to ``iter(f)``. This is a deprecated file interface that
predates Python iterator support.
"""
- pass
+ return self
+
+ @property
+ def closed(self):
+ return self._closed
+
+ # ...overrides...
def _read(self, size):
"""
@@ -236,14 +437,14 @@ class BufferedFile(ClosingContextManager):
Read data from the stream. Return ``None`` or raise ``EOFError`` to
indicate EOF.
"""
- pass
+ raise EOFError()
def _write(self, data):
"""
(subclass override)
Write data into the stream.
"""
- pass
+ raise IOError("write not implemented")
def _get_size(self):
"""
@@ -254,10 +455,74 @@ class BufferedFile(ClosingContextManager):
a stream that can't be randomly accessed, you don't need to override
this method,
"""
- pass
+ return 0
+
+ # ...internals...
- def _set_mode(self, mode='r', bufsize=-1):
+ def _set_mode(self, mode="r", bufsize=-1):
"""
Subclasses call this method to initialize the BufferedFile.
"""
- pass
+ # set bufsize in any event, because it's used for readline().
+ self._bufsize = self._DEFAULT_BUFSIZE
+ if bufsize < 0:
+ # do no buffering by default, because otherwise writes will get
+ # buffered in a way that will probably confuse people.
+ bufsize = 0
+ if bufsize == 1:
+ # apparently, line buffering only affects writes. reads are only
+ # buffered if you call readline (directly or indirectly: iterating
+ # over a file will indirectly call readline).
+ self._flags |= self.FLAG_BUFFERED | self.FLAG_LINE_BUFFERED
+ elif bufsize > 1:
+ self._bufsize = bufsize
+ self._flags |= self.FLAG_BUFFERED
+ self._flags &= ~self.FLAG_LINE_BUFFERED
+ elif bufsize == 0:
+ # unbuffered
+ self._flags &= ~(self.FLAG_BUFFERED | self.FLAG_LINE_BUFFERED)
+
+ if ("r" in mode) or ("+" in mode):
+ self._flags |= self.FLAG_READ
+ if ("w" in mode) or ("+" in mode):
+ self._flags |= self.FLAG_WRITE
+ if "a" in mode:
+ self._flags |= self.FLAG_WRITE | self.FLAG_APPEND
+ self._size = self._get_size()
+ self._pos = self._realpos = self._size
+ if "b" in mode:
+ self._flags |= self.FLAG_BINARY
+ if "U" in mode:
+ self._flags |= self.FLAG_UNIVERSAL_NEWLINE
+ # built-in file objects have this attribute to store which kinds of
+ # line terminations they've seen:
+ # <http://www.python.org/doc/current/lib/built-in-funcs.html>
+ self.newlines = None
+
+ def _write_all(self, raw_data):
+ # the underlying stream may be something that does partial writes (like
+ # a socket).
+ data = memoryview(raw_data)
+ while len(data) > 0:
+ count = self._write(data)
+ data = data[count:]
+ if self._flags & self.FLAG_APPEND:
+ self._size += count
+ self._pos = self._realpos = self._size
+ else:
+ self._pos += count
+ self._realpos += count
+ return None
+
+ def _record_newline(self, newline):
+ # silliness about tracking what kinds of newlines we've seen.
+ # i don't understand why it can be None, a string, or a tuple, instead
+ # of just always being a tuple, but we'll emulate that behavior anyway.
+ if not (self._flags & self.FLAG_UNIVERSAL_NEWLINE):
+ return
+ if self.newlines is None:
+ self.newlines = newline
+ elif self.newlines != newline and isinstance(self.newlines, bytes):
+ self.newlines = (self.newlines, newline)
+ elif newline not in self.newlines:
+ self.newlines += (newline,)
diff --git a/paramiko/hostkeys.py b/paramiko/hostkeys.py
index f2bbb85b..4d47e950 100644
--- a/paramiko/hostkeys.py
+++ b/paramiko/hostkeys.py
@@ -1,10 +1,32 @@
+# Copyright (C) 2006-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
+
from base64 import encodebytes, decodebytes
import binascii
import os
import re
+
from collections.abc import MutableMapping
from hashlib import sha1
from hmac import HMAC
+
+
from paramiko.pkey import PKey, UnknownKeyType
from paramiko.util import get_logger, constant_time_bytes_eq, b, u
from paramiko.ssh_exception import SSHException
@@ -29,6 +51,7 @@ class HostKeys(MutableMapping):
:param str filename: filename to load host keys from, or ``None``
"""
+ # emulate a dict of { hostname: { keytype: PKey } }
self._entries = []
if filename is not None:
self.load(filename)
@@ -42,7 +65,11 @@ class HostKeys(MutableMapping):
:param str keytype: key type (``"ssh-rsa"`` or ``"ssh-dss"``)
:param .PKey key: the key to add
"""
- pass
+ for e in self._entries:
+ if (hostname in e.hostnames) and (e.key.get_name() == keytype):
+ e.key = key
+ return
+ self._entries.append(HostKeyEntry([hostname], key))
def load(self, filename):
"""
@@ -59,7 +86,22 @@ class HostKeys(MutableMapping):
:raises: ``IOError`` -- if there was an error reading the file
"""
- pass
+ with open(filename, "r") as f:
+ for lineno, line in enumerate(f, 1):
+ line = line.strip()
+ if (len(line) == 0) or (line[0] == "#"):
+ continue
+ try:
+ entry = HostKeyEntry.from_line(line, lineno)
+ except SSHException:
+ continue
+ if entry is not None:
+ _hostnames = entry.hostnames
+ for h in _hostnames:
+ if self.check(h, entry.key):
+ entry.hostnames.remove(h)
+ if len(entry.hostnames):
+ self._entries.append(entry)
def save(self, filename):
"""
@@ -74,7 +116,11 @@ class HostKeys(MutableMapping):
.. versionadded:: 1.6.1
"""
- pass
+ with open(filename, "w") as f:
+ for e in self._entries:
+ line = e.to_line()
+ if line:
+ f.write(line)
def lookup(self, hostname):
"""
@@ -86,7 +132,62 @@ class HostKeys(MutableMapping):
:return: dict of `str` -> `.PKey` keys associated with this host
(or ``None``)
"""
- pass
+
+ class SubDict(MutableMapping):
+ def __init__(self, hostname, entries, hostkeys):
+ self._hostname = hostname
+ self._entries = entries
+ self._hostkeys = hostkeys
+
+ def __iter__(self):
+ for k in self.keys():
+ yield k
+
+ def __len__(self):
+ return len(self.keys())
+
+ def __delitem__(self, key):
+ for e in list(self._entries):
+ if e.key.get_name() == key:
+ self._entries.remove(e)
+ break
+ else:
+ raise KeyError(key)
+
+ def __getitem__(self, key):
+ for e in self._entries:
+ if e.key.get_name() == key:
+ return e.key
+ raise KeyError(key)
+
+ def __setitem__(self, key, val):
+ for e in self._entries:
+ if e.key is None:
+ continue
+ if e.key.get_name() == key:
+ # replace
+ e.key = val
+ break
+ else:
+ # add a new one
+ e = HostKeyEntry([hostname], val)
+ self._entries.append(e)
+ self._hostkeys._entries.append(e)
+
+ def keys(self):
+ return [
+ e.key.get_name()
+ for e in self._entries
+ if e.key is not None
+ ]
+
+ entries = []
+ for e in self._entries:
+ if self._hostname_matches(hostname, e):
+ entries.append(e)
+ if len(entries) == 0:
+ return None
+ return SubDict(hostname, entries, self)
def _hostname_matches(self, hostname, entry):
"""
@@ -94,7 +195,15 @@ class HostKeys(MutableMapping):
:returns bool:
"""
- pass
+ for h in entry.hostnames:
+ if (
+ h == hostname
+ or h.startswith("|1|")
+ and not hostname.startswith("|1|")
+ and constant_time_bytes_eq(self.hash_host(hostname, h), h)
+ ):
+ return True
+ return False
def check(self, hostname, key):
"""
@@ -106,13 +215,19 @@ class HostKeys(MutableMapping):
:return:
``True`` if the key is associated with the hostname; else ``False``
"""
- pass
+ k = self.lookup(hostname)
+ if k is None:
+ return False
+ host_key = k.get(key.get_name(), None)
+ if host_key is None:
+ return False
+ return host_key.asbytes() == key.asbytes()
def clear(self):
"""
Remove all host keys from the dictionary.
"""
- pass
+ self._entries = []
def __iter__(self):
for k in self.keys():
@@ -138,18 +253,34 @@ class HostKeys(MutableMapping):
self._entries.pop(index)
def __setitem__(self, hostname, entry):
+ # don't use this please.
if len(entry) == 0:
self._entries.append(HostKeyEntry([hostname], None))
return
for key_type in entry.keys():
found = False
for e in self._entries:
- if hostname in e.hostnames and e.key.get_name() == key_type:
+ if (hostname in e.hostnames) and e.key.get_name() == key_type:
+ # replace
e.key = entry[key_type]
found = True
if not found:
self._entries.append(HostKeyEntry([hostname], entry[key_type]))
+ def keys(self):
+ ret = []
+ for e in self._entries:
+ for h in e.hostnames:
+ if h not in ret:
+ ret.append(h)
+ return ret
+
+ def values(self):
+ ret = []
+ for k in self.keys():
+ ret.append(self.lookup(k))
+ return ret
+
@staticmethod
def hash_host(hostname, salt=None):
"""
@@ -161,15 +292,23 @@ class HostKeys(MutableMapping):
(must be 20 bytes long)
:return: the hashed hostname as a `str`
"""
- pass
+ if salt is None:
+ salt = os.urandom(sha1().digest_size)
+ else:
+ if salt.startswith("|1|"):
+ salt = salt.split("|")[2]
+ salt = decodebytes(b(salt))
+ assert len(salt) == sha1().digest_size
+ hmac = HMAC(salt, b(hostname), sha1).digest()
+ hostkey = "|1|{}|{}".format(u(encodebytes(salt)), u(encodebytes(hmac)))
+ return hostkey.replace("\n", "")
class InvalidHostKey(Exception):
-
def __init__(self, line, exc):
self.line = line
self.exc = exc
- self.args = line, exc
+ self.args = (line, exc)
class HostKeyEntry:
@@ -178,7 +317,7 @@ class HostKeyEntry:
"""
def __init__(self, hostnames=None, key=None):
- self.valid = hostnames is not None and key is not None
+ self.valid = (hostnames is not None) and (key is not None)
self.hostnames = hostnames
self.key = key
@@ -196,7 +335,36 @@ class HostKeyEntry:
:param str line: a line from an OpenSSH known_hosts file
"""
- pass
+ log = get_logger("paramiko.hostkeys")
+ fields = re.split(" |\t", line)
+ if len(fields) < 3:
+ # Bad number of fields
+ msg = "Not enough fields found in known_hosts in line {} ({!r})"
+ log.info(msg.format(lineno, line))
+ return None
+ fields = fields[:3]
+
+ names, key_type, key = fields
+ names = names.split(",")
+
+ # Decide what kind of key we're looking at and create an object
+ # to hold it accordingly.
+ try:
+ # TODO: this grew organically and doesn't seem /wrong/ per se (file
+ # read -> unicode str -> bytes for base64 decode -> decoded bytes);
+ # but in Python 3 forever land, can we simply use
+ # `base64.b64decode(str-from-file)` here?
+ key_bytes = decodebytes(b(key))
+ except binascii.Error as e:
+ raise InvalidHostKey(line, e)
+
+ try:
+ return cls(names, PKey.from_type_string(key_type, key_bytes))
+ except UnknownKeyType:
+ # TODO 4.0: consider changing HostKeys API so this just raises
+ # naturally and the exception is muted higher up in the stack?
+ log.info("Unable to handle key of type {}".format(key_type))
+ return None
def to_line(self):
"""
@@ -204,7 +372,13 @@ class HostKeyEntry:
the object is not in a valid state. A trailing newline is
included.
"""
- pass
+ if self.valid:
+ return "{} {} {}\n".format(
+ ",".join(self.hostnames),
+ self.key.get_name(),
+ self.key.get_base64(),
+ )
+ return None
def __repr__(self):
- return '<HostKeyEntry {!r}: {!r}>'.format(self.hostnames, self.key)
+ return "<HostKeyEntry {!r}: {!r}>".format(self.hostnames, self.key)
diff --git a/paramiko/kex_curve25519.py b/paramiko/kex_curve25519.py
index 1d58fd85..20c23e42 100644
--- a/paramiko/kex_curve25519.py
+++ b/paramiko/kex_curve25519.py
@@ -1,11 +1,18 @@
import binascii
import hashlib
+
from cryptography.exceptions import UnsupportedAlgorithm
from cryptography.hazmat.primitives import constant_time, serialization
-from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
+from cryptography.hazmat.primitives.asymmetric.x25519 import (
+ X25519PrivateKey,
+ X25519PublicKey,
+)
+
from paramiko.message import Message
from paramiko.common import byte_chr
from paramiko.ssh_exception import SSHException
+
+
_MSG_KEXECDH_INIT, _MSG_KEXECDH_REPLY = range(30, 32)
c_MSG_KEXECDH_INIT, c_MSG_KEXECDH_REPLY = [byte_chr(c) for c in range(30, 32)]
@@ -16,3 +23,109 @@ class KexCurve25519:
def __init__(self, transport):
self.transport = transport
self.key = None
+
+ @classmethod
+ def is_available(cls):
+ try:
+ X25519PrivateKey.generate()
+ except UnsupportedAlgorithm:
+ return False
+ else:
+ return True
+
+ def _perform_exchange(self, peer_key):
+ secret = self.key.exchange(peer_key)
+ if constant_time.bytes_eq(secret, b"\x00" * 32):
+ raise SSHException(
+ "peer's curve25519 public value has wrong order"
+ )
+ return secret
+
+ def start_kex(self):
+ self.key = X25519PrivateKey.generate()
+ if self.transport.server_mode:
+ self.transport._expect_packet(_MSG_KEXECDH_INIT)
+ return
+
+ m = Message()
+ m.add_byte(c_MSG_KEXECDH_INIT)
+ m.add_string(
+ self.key.public_key().public_bytes(
+ serialization.Encoding.Raw, serialization.PublicFormat.Raw
+ )
+ )
+ self.transport._send_message(m)
+ self.transport._expect_packet(_MSG_KEXECDH_REPLY)
+
+ def parse_next(self, ptype, m):
+ if self.transport.server_mode and (ptype == _MSG_KEXECDH_INIT):
+ return self._parse_kexecdh_init(m)
+ elif not self.transport.server_mode and (ptype == _MSG_KEXECDH_REPLY):
+ return self._parse_kexecdh_reply(m)
+ raise SSHException(
+ "KexCurve25519 asked to handle packet type {:d}".format(ptype)
+ )
+
+ def _parse_kexecdh_init(self, m):
+ peer_key_bytes = m.get_string()
+ peer_key = X25519PublicKey.from_public_bytes(peer_key_bytes)
+ K = self._perform_exchange(peer_key)
+ K = int(binascii.hexlify(K), 16)
+ # compute exchange hash
+ hm = Message()
+ hm.add(
+ self.transport.remote_version,
+ self.transport.local_version,
+ self.transport.remote_kex_init,
+ self.transport.local_kex_init,
+ )
+ server_key_bytes = self.transport.get_server_key().asbytes()
+ exchange_key_bytes = self.key.public_key().public_bytes(
+ serialization.Encoding.Raw, serialization.PublicFormat.Raw
+ )
+ hm.add_string(server_key_bytes)
+ hm.add_string(peer_key_bytes)
+ hm.add_string(exchange_key_bytes)
+ hm.add_mpint(K)
+ H = self.hash_algo(hm.asbytes()).digest()
+ self.transport._set_K_H(K, H)
+ sig = self.transport.get_server_key().sign_ssh_data(
+ H, self.transport.host_key_type
+ )
+ # construct reply
+ m = Message()
+ m.add_byte(c_MSG_KEXECDH_REPLY)
+ m.add_string(server_key_bytes)
+ m.add_string(exchange_key_bytes)
+ m.add_string(sig)
+ self.transport._send_message(m)
+ self.transport._activate_outbound()
+
+ def _parse_kexecdh_reply(self, m):
+ peer_host_key_bytes = m.get_string()
+ peer_key_bytes = m.get_string()
+ sig = m.get_binary()
+
+ peer_key = X25519PublicKey.from_public_bytes(peer_key_bytes)
+
+ K = self._perform_exchange(peer_key)
+ K = int(binascii.hexlify(K), 16)
+ # compute exchange hash and verify signature
+ hm = Message()
+ hm.add(
+ self.transport.local_version,
+ self.transport.remote_version,
+ self.transport.local_kex_init,
+ self.transport.remote_kex_init,
+ )
+ hm.add_string(peer_host_key_bytes)
+ hm.add_string(
+ self.key.public_key().public_bytes(
+ serialization.Encoding.Raw, serialization.PublicFormat.Raw
+ )
+ )
+ hm.add_string(peer_key_bytes)
+ hm.add_mpint(K)
+ self.transport._set_K_H(K, self.hash_algo(hm.asbytes()).digest())
+ self.transport._verify_key(peer_host_key_bytes, sig)
+ self.transport._activate_outbound()
diff --git a/paramiko/kex_ecdh_nist.py b/paramiko/kex_ecdh_nist.py
index b119304b..41fab46b 100644
--- a/paramiko/kex_ecdh_nist.py
+++ b/paramiko/kex_ecdh_nist.py
@@ -2,6 +2,7 @@
Ephemeral Elliptic Curve Diffie-Hellman (ECDH) key exchange
RFC 5656, Section 4
"""
+
from hashlib import sha256, sha384, sha512
from paramiko.common import byte_chr
from paramiko.message import Message
@@ -10,29 +11,141 @@ from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization
from binascii import hexlify
+
_MSG_KEXECDH_INIT, _MSG_KEXECDH_REPLY = range(30, 32)
c_MSG_KEXECDH_INIT, c_MSG_KEXECDH_REPLY = [byte_chr(c) for c in range(30, 32)]
class KexNistp256:
- name = 'ecdh-sha2-nistp256'
+
+ name = "ecdh-sha2-nistp256"
hash_algo = sha256
curve = ec.SECP256R1()
def __init__(self, transport):
self.transport = transport
+ # private key, client public and server public keys
self.P = 0
self.Q_C = None
self.Q_S = None
+ def start_kex(self):
+ self._generate_key_pair()
+ if self.transport.server_mode:
+ self.transport._expect_packet(_MSG_KEXECDH_INIT)
+ return
+ m = Message()
+ m.add_byte(c_MSG_KEXECDH_INIT)
+ # SEC1: V2.0 2.3.3 Elliptic-Curve-Point-to-Octet-String Conversion
+ m.add_string(
+ self.Q_C.public_bytes(
+ serialization.Encoding.X962,
+ serialization.PublicFormat.UncompressedPoint,
+ )
+ )
+ self.transport._send_message(m)
+ self.transport._expect_packet(_MSG_KEXECDH_REPLY)
+
+ def parse_next(self, ptype, m):
+ if self.transport.server_mode and (ptype == _MSG_KEXECDH_INIT):
+ return self._parse_kexecdh_init(m)
+ elif not self.transport.server_mode and (ptype == _MSG_KEXECDH_REPLY):
+ return self._parse_kexecdh_reply(m)
+ raise SSHException(
+ "KexECDH asked to handle packet type {:d}".format(ptype)
+ )
+
+ def _generate_key_pair(self):
+ self.P = ec.generate_private_key(self.curve, default_backend())
+ if self.transport.server_mode:
+ self.Q_S = self.P.public_key()
+ return
+ self.Q_C = self.P.public_key()
+
+ def _parse_kexecdh_init(self, m):
+ Q_C_bytes = m.get_string()
+ self.Q_C = ec.EllipticCurvePublicKey.from_encoded_point(
+ self.curve, Q_C_bytes
+ )
+ K_S = self.transport.get_server_key().asbytes()
+ K = self.P.exchange(ec.ECDH(), self.Q_C)
+ K = int(hexlify(K), 16)
+ # compute exchange hash
+ hm = Message()
+ hm.add(
+ self.transport.remote_version,
+ self.transport.local_version,
+ self.transport.remote_kex_init,
+ self.transport.local_kex_init,
+ )
+ hm.add_string(K_S)
+ hm.add_string(Q_C_bytes)
+ # SEC1: V2.0 2.3.3 Elliptic-Curve-Point-to-Octet-String Conversion
+ hm.add_string(
+ self.Q_S.public_bytes(
+ serialization.Encoding.X962,
+ serialization.PublicFormat.UncompressedPoint,
+ )
+ )
+ hm.add_mpint(int(K))
+ H = self.hash_algo(hm.asbytes()).digest()
+ self.transport._set_K_H(K, H)
+ sig = self.transport.get_server_key().sign_ssh_data(
+ H, self.transport.host_key_type
+ )
+ # construct reply
+ m = Message()
+ m.add_byte(c_MSG_KEXECDH_REPLY)
+ m.add_string(K_S)
+ m.add_string(
+ self.Q_S.public_bytes(
+ serialization.Encoding.X962,
+ serialization.PublicFormat.UncompressedPoint,
+ )
+ )
+ m.add_string(sig)
+ self.transport._send_message(m)
+ self.transport._activate_outbound()
+
+ def _parse_kexecdh_reply(self, m):
+ K_S = m.get_string()
+ Q_S_bytes = m.get_string()
+ self.Q_S = ec.EllipticCurvePublicKey.from_encoded_point(
+ self.curve, Q_S_bytes
+ )
+ sig = m.get_binary()
+ K = self.P.exchange(ec.ECDH(), self.Q_S)
+ K = int(hexlify(K), 16)
+ # compute exchange hash and verify signature
+ hm = Message()
+ hm.add(
+ self.transport.local_version,
+ self.transport.remote_version,
+ self.transport.local_kex_init,
+ self.transport.remote_kex_init,
+ )
+ hm.add_string(K_S)
+ # SEC1: V2.0 2.3.3 Elliptic-Curve-Point-to-Octet-String Conversion
+ hm.add_string(
+ self.Q_C.public_bytes(
+ serialization.Encoding.X962,
+ serialization.PublicFormat.UncompressedPoint,
+ )
+ )
+ hm.add_string(Q_S_bytes)
+ hm.add_mpint(K)
+ self.transport._set_K_H(K, self.hash_algo(hm.asbytes()).digest())
+ self.transport._verify_key(K_S, sig)
+ self.transport._activate_outbound()
+
class KexNistp384(KexNistp256):
- name = 'ecdh-sha2-nistp384'
+ name = "ecdh-sha2-nistp384"
hash_algo = sha384
curve = ec.SECP384R1()
class KexNistp521(KexNistp256):
- name = 'ecdh-sha2-nistp521'
+ name = "ecdh-sha2-nistp521"
hash_algo = sha512
curve = ec.SECP521R1()
diff --git a/paramiko/kex_gex.py b/paramiko/kex_gex.py
index c7455ab2..baa0803d 100644
--- a/paramiko/kex_gex.py
+++ b/paramiko/kex_gex.py
@@ -1,23 +1,56 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Variant on `KexGroup1 <paramiko.kex_group1.KexGroup1>` where the prime "p" and
generator "g" are provided by the server. A bit more work is required on the
client side, and a **lot** more on the server side.
"""
+
import os
from hashlib import sha1, sha256
+
from paramiko import util
from paramiko.common import DEBUG, byte_chr, byte_ord, byte_mask
from paramiko.message import Message
from paramiko.ssh_exception import SSHException
-(_MSG_KEXDH_GEX_REQUEST_OLD, _MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT,
- _MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST) = range(30, 35)
-(c_MSG_KEXDH_GEX_REQUEST_OLD, c_MSG_KEXDH_GEX_GROUP, c_MSG_KEXDH_GEX_INIT,
- c_MSG_KEXDH_GEX_REPLY, c_MSG_KEXDH_GEX_REQUEST) = [byte_chr(c) for c in
- range(30, 35)]
+
+
+(
+ _MSG_KEXDH_GEX_REQUEST_OLD,
+ _MSG_KEXDH_GEX_GROUP,
+ _MSG_KEXDH_GEX_INIT,
+ _MSG_KEXDH_GEX_REPLY,
+ _MSG_KEXDH_GEX_REQUEST,
+) = range(30, 35)
+
+(
+ c_MSG_KEXDH_GEX_REQUEST_OLD,
+ c_MSG_KEXDH_GEX_GROUP,
+ c_MSG_KEXDH_GEX_INIT,
+ c_MSG_KEXDH_GEX_REPLY,
+ c_MSG_KEXDH_GEX_REQUEST,
+) = [byte_chr(c) for c in range(30, 35)]
class KexGex:
- name = 'diffie-hellman-group-exchange-sha1'
+
+ name = "diffie-hellman-group-exchange-sha1"
min_bits = 1024
max_bits = 8192
preferred_bits = 2048
@@ -33,7 +66,223 @@ class KexGex:
self.f = None
self.old_style = False
+ def start_kex(self, _test_old_style=False):
+ if self.transport.server_mode:
+ self.transport._expect_packet(
+ _MSG_KEXDH_GEX_REQUEST, _MSG_KEXDH_GEX_REQUEST_OLD
+ )
+ return
+ # request a bit range: we accept (min_bits) to (max_bits), but prefer
+ # (preferred_bits). according to the spec, we shouldn't pull the
+ # minimum up above 1024.
+ m = Message()
+ if _test_old_style:
+ # only used for unit tests: we shouldn't ever send this
+ m.add_byte(c_MSG_KEXDH_GEX_REQUEST_OLD)
+ m.add_int(self.preferred_bits)
+ self.old_style = True
+ else:
+ m.add_byte(c_MSG_KEXDH_GEX_REQUEST)
+ m.add_int(self.min_bits)
+ m.add_int(self.preferred_bits)
+ m.add_int(self.max_bits)
+ self.transport._send_message(m)
+ self.transport._expect_packet(_MSG_KEXDH_GEX_GROUP)
+
+ def parse_next(self, ptype, m):
+ if ptype == _MSG_KEXDH_GEX_REQUEST:
+ return self._parse_kexdh_gex_request(m)
+ elif ptype == _MSG_KEXDH_GEX_GROUP:
+ return self._parse_kexdh_gex_group(m)
+ elif ptype == _MSG_KEXDH_GEX_INIT:
+ return self._parse_kexdh_gex_init(m)
+ elif ptype == _MSG_KEXDH_GEX_REPLY:
+ return self._parse_kexdh_gex_reply(m)
+ elif ptype == _MSG_KEXDH_GEX_REQUEST_OLD:
+ return self._parse_kexdh_gex_request_old(m)
+ msg = "KexGex {} asked to handle packet type {:d}"
+ raise SSHException(msg.format(self.name, ptype))
+
+ # ...internals...
+
+ def _generate_x(self):
+ # generate an "x" (1 < x < (p-1)/2).
+ q = (self.p - 1) // 2
+ qnorm = util.deflate_long(q, 0)
+ qhbyte = byte_ord(qnorm[0])
+ byte_count = len(qnorm)
+ qmask = 0xFF
+ while not (qhbyte & 0x80):
+ qhbyte <<= 1
+ qmask >>= 1
+ while True:
+ x_bytes = os.urandom(byte_count)
+ x_bytes = byte_mask(x_bytes[0], qmask) + x_bytes[1:]
+ x = util.inflate_long(x_bytes, 1)
+ if (x > 1) and (x < q):
+ break
+ self.x = x
+
+ def _parse_kexdh_gex_request(self, m):
+ minbits = m.get_int()
+ preferredbits = m.get_int()
+ maxbits = m.get_int()
+ # smoosh the user's preferred size into our own limits
+ if preferredbits > self.max_bits:
+ preferredbits = self.max_bits
+ if preferredbits < self.min_bits:
+ preferredbits = self.min_bits
+ # fix min/max if they're inconsistent. technically, we could just pout
+ # and hang up, but there's no harm in giving them the benefit of the
+ # doubt and just picking a bitsize for them.
+ if minbits > preferredbits:
+ minbits = preferredbits
+ if maxbits < preferredbits:
+ maxbits = preferredbits
+ # now save a copy
+ self.min_bits = minbits
+ self.preferred_bits = preferredbits
+ self.max_bits = maxbits
+ # generate prime
+ pack = self.transport._get_modulus_pack()
+ if pack is None:
+ raise SSHException("Can't do server-side gex with no modulus pack")
+ self.transport._log(
+ DEBUG,
+ "Picking p ({} <= {} <= {} bits)".format(
+ minbits, preferredbits, maxbits
+ ),
+ )
+ self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits)
+ m = Message()
+ m.add_byte(c_MSG_KEXDH_GEX_GROUP)
+ m.add_mpint(self.p)
+ m.add_mpint(self.g)
+ self.transport._send_message(m)
+ self.transport._expect_packet(_MSG_KEXDH_GEX_INIT)
+
+ def _parse_kexdh_gex_request_old(self, m):
+ # same as above, but without min_bits or max_bits (used by older
+ # clients like putty)
+ self.preferred_bits = m.get_int()
+ # smoosh the user's preferred size into our own limits
+ if self.preferred_bits > self.max_bits:
+ self.preferred_bits = self.max_bits
+ if self.preferred_bits < self.min_bits:
+ self.preferred_bits = self.min_bits
+ # generate prime
+ pack = self.transport._get_modulus_pack()
+ if pack is None:
+ raise SSHException("Can't do server-side gex with no modulus pack")
+ self.transport._log(
+ DEBUG, "Picking p (~ {} bits)".format(self.preferred_bits)
+ )
+ self.g, self.p = pack.get_modulus(
+ self.min_bits, self.preferred_bits, self.max_bits
+ )
+ m = Message()
+ m.add_byte(c_MSG_KEXDH_GEX_GROUP)
+ m.add_mpint(self.p)
+ m.add_mpint(self.g)
+ self.transport._send_message(m)
+ self.transport._expect_packet(_MSG_KEXDH_GEX_INIT)
+ self.old_style = True
+
+ def _parse_kexdh_gex_group(self, m):
+ self.p = m.get_mpint()
+ self.g = m.get_mpint()
+ # reject if p's bit length < 1024 or > 8192
+ bitlen = util.bit_length(self.p)
+ if (bitlen < 1024) or (bitlen > 8192):
+ raise SSHException(
+ "Server-generated gex p (don't ask) is out of range "
+ "({} bits)".format(bitlen)
+ )
+ self.transport._log(DEBUG, "Got server p ({} bits)".format(bitlen))
+ self._generate_x()
+ # now compute e = g^x mod p
+ self.e = pow(self.g, self.x, self.p)
+ m = Message()
+ m.add_byte(c_MSG_KEXDH_GEX_INIT)
+ m.add_mpint(self.e)
+ self.transport._send_message(m)
+ self.transport._expect_packet(_MSG_KEXDH_GEX_REPLY)
+
+ def _parse_kexdh_gex_init(self, m):
+ self.e = m.get_mpint()
+ if (self.e < 1) or (self.e > self.p - 1):
+ raise SSHException('Client kex "e" is out of range')
+ self._generate_x()
+ self.f = pow(self.g, self.x, self.p)
+ K = pow(self.e, self.x, self.p)
+ key = self.transport.get_server_key().asbytes()
+ # okay, build up the hash H of
+ # (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa
+ hm = Message()
+ hm.add(
+ self.transport.remote_version,
+ self.transport.local_version,
+ self.transport.remote_kex_init,
+ self.transport.local_kex_init,
+ key,
+ )
+ if not self.old_style:
+ hm.add_int(self.min_bits)
+ hm.add_int(self.preferred_bits)
+ if not self.old_style:
+ hm.add_int(self.max_bits)
+ hm.add_mpint(self.p)
+ hm.add_mpint(self.g)
+ hm.add_mpint(self.e)
+ hm.add_mpint(self.f)
+ hm.add_mpint(K)
+ H = self.hash_algo(hm.asbytes()).digest()
+ self.transport._set_K_H(K, H)
+ # sign it
+ sig = self.transport.get_server_key().sign_ssh_data(
+ H, self.transport.host_key_type
+ )
+ # send reply
+ m = Message()
+ m.add_byte(c_MSG_KEXDH_GEX_REPLY)
+ m.add_string(key)
+ m.add_mpint(self.f)
+ m.add_string(sig)
+ self.transport._send_message(m)
+ self.transport._activate_outbound()
+
+ def _parse_kexdh_gex_reply(self, m):
+ host_key = m.get_string()
+ self.f = m.get_mpint()
+ sig = m.get_string()
+ if (self.f < 1) or (self.f > self.p - 1):
+ raise SSHException('Server kex "f" is out of range')
+ K = pow(self.f, self.x, self.p)
+ # okay, build up the hash H of
+ # (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa
+ hm = Message()
+ hm.add(
+ self.transport.local_version,
+ self.transport.remote_version,
+ self.transport.local_kex_init,
+ self.transport.remote_kex_init,
+ host_key,
+ )
+ if not self.old_style:
+ hm.add_int(self.min_bits)
+ hm.add_int(self.preferred_bits)
+ if not self.old_style:
+ hm.add_int(self.max_bits)
+ hm.add_mpint(self.p)
+ hm.add_mpint(self.g)
+ hm.add_mpint(self.e)
+ hm.add_mpint(self.f)
+ hm.add_mpint(K)
+ self.transport._set_K_H(K, self.hash_algo(hm.asbytes()).digest())
+ self.transport._verify_key(host_key, sig)
+ self.transport._activate_outbound()
+
class KexGexSHA256(KexGex):
- name = 'diffie-hellman-group-exchange-sha256'
+ name = "diffie-hellman-group-exchange-sha256"
hash_algo = sha256
diff --git a/paramiko/kex_group1.py b/paramiko/kex_group1.py
index cb6aa737..f0742566 100644
--- a/paramiko/kex_group1.py
+++ b/paramiko/kex_group1.py
@@ -1,25 +1,49 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Standard SSH key exchange ("kex" if you wanna sound cool). Diffie-Hellman of
1024 bit key halves, using a known "p" prime and "g" generator.
"""
+
import os
from hashlib import sha1
+
from paramiko import util
from paramiko.common import max_byte, zero_byte, byte_chr, byte_mask
from paramiko.message import Message
from paramiko.ssh_exception import SSHException
+
+
_MSG_KEXDH_INIT, _MSG_KEXDH_REPLY = range(30, 32)
c_MSG_KEXDH_INIT, c_MSG_KEXDH_REPLY = [byte_chr(c) for c in range(30, 32)]
-b7fffffffffffffff = byte_chr(127) + max_byte * 7
+
+b7fffffffffffffff = byte_chr(0x7F) + max_byte * 7
b0000000000000000 = zero_byte * 8
class KexGroup1:
- P = (
- 179769313486231590770839156793787453197860296048756011706444423684197180216158519368947833795864925541502180565485980503646440548199239100050792877003355816639229553136239076508735759914822574862575007425302077447712589550957937778424442426617334727629299387668709205606050270810842907692932019128194467627007
- )
+
+ # draft-ietf-secsh-transport-09.txt, page 17
+ P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF # noqa
G = 2
- name = 'diffie-hellman-group1-sha1'
+
+ name = "diffie-hellman-group1-sha1"
hash_algo = sha1
def __init__(self, transport):
@@ -27,3 +51,105 @@ class KexGroup1:
self.x = 0
self.e = 0
self.f = 0
+
+ def start_kex(self):
+ self._generate_x()
+ if self.transport.server_mode:
+ # compute f = g^x mod p, but don't send it yet
+ self.f = pow(self.G, self.x, self.P)
+ self.transport._expect_packet(_MSG_KEXDH_INIT)
+ return
+ # compute e = g^x mod p (where g=2), and send it
+ self.e = pow(self.G, self.x, self.P)
+ m = Message()
+ m.add_byte(c_MSG_KEXDH_INIT)
+ m.add_mpint(self.e)
+ self.transport._send_message(m)
+ self.transport._expect_packet(_MSG_KEXDH_REPLY)
+
+ def parse_next(self, ptype, m):
+ if self.transport.server_mode and (ptype == _MSG_KEXDH_INIT):
+ return self._parse_kexdh_init(m)
+ elif not self.transport.server_mode and (ptype == _MSG_KEXDH_REPLY):
+ return self._parse_kexdh_reply(m)
+ msg = "KexGroup1 asked to handle packet type {:d}"
+ raise SSHException(msg.format(ptype))
+
+ # ...internals...
+
+ def _generate_x(self):
+ # generate an "x" (1 < x < q), where q is (p-1)/2.
+ # p is a 128-byte (1024-bit) number, where the first 64 bits are 1.
+ # therefore q can be approximated as a 2^1023. we drop the subset of
+ # potential x where the first 63 bits are 1, because some of those
+ # will be larger than q (but this is a tiny tiny subset of
+ # potential x).
+ while 1:
+ x_bytes = os.urandom(128)
+ x_bytes = byte_mask(x_bytes[0], 0x7F) + x_bytes[1:]
+ if (
+ x_bytes[:8] != b7fffffffffffffff
+ and x_bytes[:8] != b0000000000000000
+ ):
+ break
+ self.x = util.inflate_long(x_bytes)
+
+ def _parse_kexdh_reply(self, m):
+ # client mode
+ host_key = m.get_string()
+ self.f = m.get_mpint()
+ if (self.f < 1) or (self.f > self.P - 1):
+ raise SSHException('Server kex "f" is out of range')
+ sig = m.get_binary()
+ K = pow(self.f, self.x, self.P)
+ # okay, build up the hash H of
+ # (V_C || V_S || I_C || I_S || K_S || e || f || K)
+ hm = Message()
+ hm.add(
+ self.transport.local_version,
+ self.transport.remote_version,
+ self.transport.local_kex_init,
+ self.transport.remote_kex_init,
+ )
+ hm.add_string(host_key)
+ hm.add_mpint(self.e)
+ hm.add_mpint(self.f)
+ hm.add_mpint(K)
+ self.transport._set_K_H(K, self.hash_algo(hm.asbytes()).digest())
+ self.transport._verify_key(host_key, sig)
+ self.transport._activate_outbound()
+
+ def _parse_kexdh_init(self, m):
+ # server mode
+ self.e = m.get_mpint()
+ if (self.e < 1) or (self.e > self.P - 1):
+ raise SSHException('Client kex "e" is out of range')
+ K = pow(self.e, self.x, self.P)
+ key = self.transport.get_server_key().asbytes()
+ # okay, build up the hash H of
+ # (V_C || V_S || I_C || I_S || K_S || e || f || K)
+ hm = Message()
+ hm.add(
+ self.transport.remote_version,
+ self.transport.local_version,
+ self.transport.remote_kex_init,
+ self.transport.local_kex_init,
+ )
+ hm.add_string(key)
+ hm.add_mpint(self.e)
+ hm.add_mpint(self.f)
+ hm.add_mpint(K)
+ H = self.hash_algo(hm.asbytes()).digest()
+ self.transport._set_K_H(K, H)
+ # sign it
+ sig = self.transport.get_server_key().sign_ssh_data(
+ H, self.transport.host_key_type
+ )
+ # send reply
+ m = Message()
+ m.add_byte(c_MSG_KEXDH_REPLY)
+ m.add_string(key)
+ m.add_mpint(self.f)
+ m.add_string(sig)
+ self.transport._send_message(m)
+ self.transport._activate_outbound()
diff --git a/paramiko/kex_group14.py b/paramiko/kex_group14.py
index 91525869..8dee5515 100644
--- a/paramiko/kex_group14.py
+++ b/paramiko/kex_group14.py
@@ -1,20 +1,40 @@
+# Copyright (C) 2013 Torsten Landschoff <torsten@debian.org>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Standard SSH key exchange ("kex" if you wanna sound cool). Diffie-Hellman of
2048 bit key halves, using a known "p" prime and "g" generator.
"""
+
from paramiko.kex_group1 import KexGroup1
from hashlib import sha1, sha256
class KexGroup14(KexGroup1):
- P = (
- 32317006071311007300338913926423828248817941241140239112842009751400741706634354222619689417363569347117901737909704191754605873209195028853758986185622153212175412514901774520270235796078236248884246189477587641105928646099411723245426622522193230540919037680524235519125679715870117001058055877651038861847280257976054903569732561526167081339361799541336476559160368317896729073178384589680639671900977202194168647225871031411336429319536193471636533209717077448227988588565369208645296636077250268955505928362751121174096972998068410554359584866583291642136218231078990999448652468262416972035911852507045361090559
- )
+
+ # http://tools.ietf.org/html/rfc3526#section-3
+ P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF # noqa
G = 2
- name = 'diffie-hellman-group14-sha1'
+
+ name = "diffie-hellman-group14-sha1"
hash_algo = sha1
class KexGroup14SHA256(KexGroup14):
- name = 'diffie-hellman-group14-sha256'
+ name = "diffie-hellman-group14-sha256"
hash_algo = sha256
diff --git a/paramiko/kex_group16.py b/paramiko/kex_group16.py
index f1223757..c675f877 100644
--- a/paramiko/kex_group16.py
+++ b/paramiko/kex_group16.py
@@ -1,16 +1,35 @@
+# Copyright (C) 2019 Edgar Sousa <https://github.com/edgsousa>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Standard SSH key exchange ("kex" if you wanna sound cool). Diffie-Hellman of
4096 bit key halves, using a known "p" prime and "g" generator.
"""
+
from paramiko.kex_group1 import KexGroup1
from hashlib import sha512
class KexGroup16SHA512(KexGroup1):
- name = 'diffie-hellman-group16-sha512'
- P = (
- 1044388881413152506679602719846529545831269060992135009022588756444338172022322690710444046669809783930111585737890362691860127079270495454517218673016928427459146001866885779762982229321192368303346235204368051010309155674155697460347176946394076535157284994895284821633700921811716738972451834979455897010306333468590751358365138782250372269117968985194322444535687415522007151638638141456178420621277822674995027990278673458629544391736919766299005511505446177668154446234882665961680796576903199116089347634947187778906528008004756692571666922964122566174582776707332452371001272163776841229318324903125740713574141005124561965913888899753461735347970011693256316751660678950830027510255804846105583465055446615090444309583050775808509297040039680057435342253926566240898195863631588888936364129920059308455669454034010391478238784189888594672336242763795138176353222845524644040094258962433613354036104643881925238489224010194193088911666165584229424668165441688927790460608264864204237717002054744337988941974661214699689706521543006262604535890998125752275942608772174376107314217749233048217904944409836238235772306749874396760463376480215133461333478395682746608242585133953883882226786118030184028136755970045385534758453247
- )
+ name = "diffie-hellman-group16-sha512"
+ # http://tools.ietf.org/html/rfc3526#section-5
+ P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF # noqa
G = 2
- name = 'diffie-hellman-group16-sha512'
+
+ name = "diffie-hellman-group16-sha512"
hash_algo = sha512
diff --git a/paramiko/kex_gss.py b/paramiko/kex_gss.py
index 50d792e4..2a5f29e3 100644
--- a/paramiko/kex_gss.py
+++ b/paramiko/kex_gss.py
@@ -1,3 +1,25 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+# Copyright (C) 2013-2014 science + computing ag
+# Author: Sebastian Deiss <sebastian.deiss@t-online.de>
+#
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
+
"""
This module provides GSS-API / SSPI Key Exchange as defined in :rfc:`4462`.
@@ -14,20 +36,41 @@ This module provides GSS-API / SSPI Key Exchange as defined in :rfc:`4462`.
.. versionadded:: 1.15
"""
+
import os
from hashlib import sha1
-from paramiko.common import DEBUG, max_byte, zero_byte, byte_chr, byte_mask, byte_ord
+
+from paramiko.common import (
+ DEBUG,
+ max_byte,
+ zero_byte,
+ byte_chr,
+ byte_mask,
+ byte_ord,
+)
from paramiko import util
from paramiko.message import Message
from paramiko.ssh_exception import SSHException
-(MSG_KEXGSS_INIT, MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE,
- MSG_KEXGSS_HOSTKEY, MSG_KEXGSS_ERROR) = range(30, 35)
-MSG_KEXGSS_GROUPREQ, MSG_KEXGSS_GROUP = range(40, 42)
-(c_MSG_KEXGSS_INIT, c_MSG_KEXGSS_CONTINUE, c_MSG_KEXGSS_COMPLETE,
- c_MSG_KEXGSS_HOSTKEY, c_MSG_KEXGSS_ERROR) = [byte_chr(c) for c in range
- (30, 35)]
-c_MSG_KEXGSS_GROUPREQ, c_MSG_KEXGSS_GROUP = [byte_chr(c) for c in range(40, 42)
- ]
+
+
+(
+ MSG_KEXGSS_INIT,
+ MSG_KEXGSS_CONTINUE,
+ MSG_KEXGSS_COMPLETE,
+ MSG_KEXGSS_HOSTKEY,
+ MSG_KEXGSS_ERROR,
+) = range(30, 35)
+(MSG_KEXGSS_GROUPREQ, MSG_KEXGSS_GROUP) = range(40, 42)
+(
+ c_MSG_KEXGSS_INIT,
+ c_MSG_KEXGSS_CONTINUE,
+ c_MSG_KEXGSS_COMPLETE,
+ c_MSG_KEXGSS_HOSTKEY,
+ c_MSG_KEXGSS_ERROR,
+) = [byte_chr(c) for c in range(30, 35)]
+(c_MSG_KEXGSS_GROUPREQ, c_MSG_KEXGSS_GROUP) = [
+ byte_chr(c) for c in range(40, 42)
+]
class KexGSSGroup1:
@@ -35,13 +78,13 @@ class KexGSSGroup1:
GSS-API / SSPI Authenticated Diffie-Hellman Key Exchange as defined in `RFC
4462 Section 2 <https://tools.ietf.org/html/rfc4462.html#section-2>`_
"""
- P = (
- 179769313486231590770839156793787453197860296048756011706444423684197180216158519368947833795864925541502180565485980503646440548199239100050792877003355816639229553136239076508735759914822574862575007425302077447712589550957937778424442426617334727629299387668709205606050270810842907692932019128194467627007
- )
+
+ # draft-ietf-secsh-transport-09.txt, page 17
+ P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF # noqa
G = 2
- b7fffffffffffffff = byte_chr(127) + max_byte * 7
- b0000000000000000 = zero_byte * 8
- NAME = 'gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g=='
+ b7fffffffffffffff = byte_chr(0x7F) + max_byte * 7 # noqa
+ b0000000000000000 = zero_byte * 8 # noqa
+ NAME = "gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g=="
def __init__(self, transport):
self.transport = transport
@@ -55,7 +98,27 @@ class KexGSSGroup1:
"""
Start the GSS-API / SSPI Authenticated Diffie-Hellman Key Exchange.
"""
- pass
+ self._generate_x()
+ if self.transport.server_mode:
+ # compute f = g^x mod p, but don't send it yet
+ self.f = pow(self.G, self.x, self.P)
+ self.transport._expect_packet(MSG_KEXGSS_INIT)
+ return
+ # compute e = g^x mod p (where g=2), and send it
+ self.e = pow(self.G, self.x, self.P)
+ # Initialize GSS-API Key Exchange
+ self.gss_host = self.transport.gss_host
+ m = Message()
+ m.add_byte(c_MSG_KEXGSS_INIT)
+ m.add_string(self.kexgss.ssh_init_sec_context(target=self.gss_host))
+ m.add_mpint(self.e)
+ self.transport._send_message(m)
+ self.transport._expect_packet(
+ MSG_KEXGSS_HOSTKEY,
+ MSG_KEXGSS_CONTINUE,
+ MSG_KEXGSS_COMPLETE,
+ MSG_KEXGSS_ERROR,
+ )
def parse_next(self, ptype, m):
"""
@@ -64,7 +127,20 @@ class KexGSSGroup1:
:param ptype: The (string) type of the incoming packet
:param `.Message` m: The packet content
"""
- pass
+ if self.transport.server_mode and (ptype == MSG_KEXGSS_INIT):
+ return self._parse_kexgss_init(m)
+ elif not self.transport.server_mode and (ptype == MSG_KEXGSS_HOSTKEY):
+ return self._parse_kexgss_hostkey(m)
+ elif self.transport.server_mode and (ptype == MSG_KEXGSS_CONTINUE):
+ return self._parse_kexgss_continue(m)
+ elif not self.transport.server_mode and (ptype == MSG_KEXGSS_COMPLETE):
+ return self._parse_kexgss_complete(m)
+ elif ptype == MSG_KEXGSS_ERROR:
+ return self._parse_kexgss_error(m)
+ msg = "GSS KexGroup1 asked to handle packet type {:d}"
+ raise SSHException(msg.format(ptype))
+
+ # ## internals...
def _generate_x(self):
"""
@@ -74,7 +150,13 @@ class KexGSSGroup1:
potential x where the first 63 bits are 1, because some of those will
be larger than q (but this is a tiny tiny subset of potential x).
"""
- pass
+ while 1:
+ x_bytes = os.urandom(128)
+ x_bytes = byte_mask(x_bytes[0], 0x7F) + x_bytes[1:]
+ first = x_bytes[:8]
+ if first not in (self.b7fffffffffffffff, self.b0000000000000000):
+ break
+ self.x = util.inflate_long(x_bytes)
def _parse_kexgss_hostkey(self, m):
"""
@@ -82,7 +164,12 @@ class KexGSSGroup1:
:param `.Message` m: The content of the SSH2_MSG_KEXGSS_HOSTKEY message
"""
- pass
+ # client mode
+ host_key = m.get_string()
+ self.transport.host_key = host_key
+ sig = m.get_string()
+ self.transport._verify_key(host_key, sig)
+ self.transport._expect_packet(MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE)
def _parse_kexgss_continue(self, m):
"""
@@ -91,7 +178,21 @@ class KexGSSGroup1:
:param `.Message` m: The content of the SSH2_MSG_KEXGSS_CONTINUE
message
"""
- pass
+ if not self.transport.server_mode:
+ srv_token = m.get_string()
+ m = Message()
+ m.add_byte(c_MSG_KEXGSS_CONTINUE)
+ m.add_string(
+ self.kexgss.ssh_init_sec_context(
+ target=self.gss_host, recv_token=srv_token
+ )
+ )
+ self.transport.send_message(m)
+ self.transport._expect_packet(
+ MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR
+ )
+ else:
+ pass
def _parse_kexgss_complete(self, m):
"""
@@ -100,7 +201,43 @@ class KexGSSGroup1:
:param `.Message` m: The content of the
SSH2_MSG_KEXGSS_COMPLETE message
"""
- pass
+ # client mode
+ if self.transport.host_key is None:
+ self.transport.host_key = NullHostKey()
+ self.f = m.get_mpint()
+ if (self.f < 1) or (self.f > self.P - 1):
+ raise SSHException('Server kex "f" is out of range')
+ mic_token = m.get_string()
+ # This must be TRUE, if there is a GSS-API token in this message.
+ bool = m.get_boolean()
+ srv_token = None
+ if bool:
+ srv_token = m.get_string()
+ K = pow(self.f, self.x, self.P)
+ # okay, build up the hash H of
+ # (V_C || V_S || I_C || I_S || K_S || e || f || K)
+ hm = Message()
+ hm.add(
+ self.transport.local_version,
+ self.transport.remote_version,
+ self.transport.local_kex_init,
+ self.transport.remote_kex_init,
+ )
+ hm.add_string(self.transport.host_key.__str__())
+ hm.add_mpint(self.e)
+ hm.add_mpint(self.f)
+ hm.add_mpint(K)
+ H = sha1(str(hm)).digest()
+ self.transport._set_K_H(K, H)
+ if srv_token is not None:
+ self.kexgss.ssh_init_sec_context(
+ target=self.gss_host, recv_token=srv_token
+ )
+ self.kexgss.ssh_check_mic(mic_token, H)
+ else:
+ self.kexgss.ssh_check_mic(mic_token, H)
+ self.transport.gss_kex_used = True
+ self.transport._activate_outbound()
def _parse_kexgss_init(self, m):
"""
@@ -108,7 +245,55 @@ class KexGSSGroup1:
:param `.Message` m: The content of the SSH2_MSG_KEXGSS_INIT message
"""
- pass
+ # server mode
+ client_token = m.get_string()
+ self.e = m.get_mpint()
+ if (self.e < 1) or (self.e > self.P - 1):
+ raise SSHException('Client kex "e" is out of range')
+ K = pow(self.e, self.x, self.P)
+ self.transport.host_key = NullHostKey()
+ key = self.transport.host_key.__str__()
+ # okay, build up the hash H of
+ # (V_C || V_S || I_C || I_S || K_S || e || f || K)
+ hm = Message()
+ hm.add(
+ self.transport.remote_version,
+ self.transport.local_version,
+ self.transport.remote_kex_init,
+ self.transport.local_kex_init,
+ )
+ hm.add_string(key)
+ hm.add_mpint(self.e)
+ hm.add_mpint(self.f)
+ hm.add_mpint(K)
+ H = sha1(hm.asbytes()).digest()
+ self.transport._set_K_H(K, H)
+ srv_token = self.kexgss.ssh_accept_sec_context(
+ self.gss_host, client_token
+ )
+ m = Message()
+ if self.kexgss._gss_srv_ctxt_status:
+ mic_token = self.kexgss.ssh_get_mic(
+ self.transport.session_id, gss_kex=True
+ )
+ m.add_byte(c_MSG_KEXGSS_COMPLETE)
+ m.add_mpint(self.f)
+ m.add_string(mic_token)
+ if srv_token is not None:
+ m.add_boolean(True)
+ m.add_string(srv_token)
+ else:
+ m.add_boolean(False)
+ self.transport._send_message(m)
+ self.transport.gss_kex_used = True
+ self.transport._activate_outbound()
+ else:
+ m.add_byte(c_MSG_KEXGSS_CONTINUE)
+ m.add_string(srv_token)
+ self.transport._send_message(m)
+ self.transport._expect_packet(
+ MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR
+ )
def _parse_kexgss_error(self, m):
"""
@@ -121,7 +306,19 @@ class KexGSSGroup1:
the error message and the language tag of the
message
"""
- pass
+ maj_status = m.get_int()
+ min_status = m.get_int()
+ err_msg = m.get_string()
+ m.get_string() # we don't care about the language!
+ raise SSHException(
+ """GSS-API Error:
+Major Status: {}
+Minor Status: {}
+Error Message: {}
+""".format(
+ maj_status, min_status, err_msg
+ )
+ )
class KexGSSGroup14(KexGSSGroup1):
@@ -130,11 +327,10 @@ class KexGSSGroup14(KexGSSGroup1):
in `RFC 4462 Section 2
<https://tools.ietf.org/html/rfc4462.html#section-2>`_
"""
- P = (
- 32317006071311007300338913926423828248817941241140239112842009751400741706634354222619689417363569347117901737909704191754605873209195028853758986185622153212175412514901774520270235796078236248884246189477587641105928646099411723245426622522193230540919037680524235519125679715870117001058055877651038861847280257976054903569732561526167081339361799541336476559160368317896729073178384589680639671900977202194168647225871031411336429319536193471636533209717077448227988588565369208645296636077250268955505928362751121174096972998068410554359584866583291642136218231078990999448652468262416972035911852507045361090559
- )
+
+ P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF # noqa
G = 2
- NAME = 'gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g=='
+ NAME = "gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g=="
class KexGSSGex:
@@ -142,7 +338,8 @@ class KexGSSGex:
GSS-API / SSPI Authenticated Diffie-Hellman Group Exchange as defined in
`RFC 4462 Section 2 <https://tools.ietf.org/html/rfc4462.html#section-2>`_
"""
- NAME = 'gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g=='
+
+ NAME = "gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g=="
min_bits = 1024
max_bits = 8192
preferred_bits = 2048
@@ -163,7 +360,20 @@ class KexGSSGex:
"""
Start the GSS-API / SSPI Authenticated Diffie-Hellman Group Exchange
"""
- pass
+ if self.transport.server_mode:
+ self.transport._expect_packet(MSG_KEXGSS_GROUPREQ)
+ return
+ # request a bit range: we accept (min_bits) to (max_bits), but prefer
+ # (preferred_bits). according to the spec, we shouldn't pull the
+ # minimum up above 1024.
+ self.gss_host = self.transport.gss_host
+ m = Message()
+ m.add_byte(c_MSG_KEXGSS_GROUPREQ)
+ m.add_int(self.min_bits)
+ m.add_int(self.preferred_bits)
+ m.add_int(self.max_bits)
+ self.transport._send_message(m)
+ self.transport._expect_packet(MSG_KEXGSS_GROUP)
def parse_next(self, ptype, m):
"""
@@ -172,7 +382,42 @@ class KexGSSGex:
:param ptype: The (string) type of the incoming packet
:param `.Message` m: The packet content
"""
- pass
+ if ptype == MSG_KEXGSS_GROUPREQ:
+ return self._parse_kexgss_groupreq(m)
+ elif ptype == MSG_KEXGSS_GROUP:
+ return self._parse_kexgss_group(m)
+ elif ptype == MSG_KEXGSS_INIT:
+ return self._parse_kexgss_gex_init(m)
+ elif ptype == MSG_KEXGSS_HOSTKEY:
+ return self._parse_kexgss_hostkey(m)
+ elif ptype == MSG_KEXGSS_CONTINUE:
+ return self._parse_kexgss_continue(m)
+ elif ptype == MSG_KEXGSS_COMPLETE:
+ return self._parse_kexgss_complete(m)
+ elif ptype == MSG_KEXGSS_ERROR:
+ return self._parse_kexgss_error(m)
+ msg = "KexGex asked to handle packet type {:d}"
+ raise SSHException(msg.format(ptype))
+
+ # ## internals...
+
+ def _generate_x(self):
+ # generate an "x" (1 < x < (p-1)/2).
+ q = (self.p - 1) // 2
+ qnorm = util.deflate_long(q, 0)
+ qhbyte = byte_ord(qnorm[0])
+ byte_count = len(qnorm)
+ qmask = 0xFF
+ while not (qhbyte & 0x80):
+ qhbyte <<= 1
+ qmask >>= 1
+ while True:
+ x_bytes = os.urandom(byte_count)
+ x_bytes = byte_mask(x_bytes[0], qmask) + x_bytes[1:]
+ x = util.inflate_long(x_bytes, 1)
+ if (x > 1) and (x < q):
+ break
+ self.x = x
def _parse_kexgss_groupreq(self, m):
"""
@@ -181,7 +426,42 @@ class KexGSSGex:
:param `.Message` m: The content of the
SSH2_MSG_KEXGSS_GROUPREQ message
"""
- pass
+ minbits = m.get_int()
+ preferredbits = m.get_int()
+ maxbits = m.get_int()
+ # smoosh the user's preferred size into our own limits
+ if preferredbits > self.max_bits:
+ preferredbits = self.max_bits
+ if preferredbits < self.min_bits:
+ preferredbits = self.min_bits
+ # fix min/max if they're inconsistent. technically, we could just pout
+ # and hang up, but there's no harm in giving them the benefit of the
+ # doubt and just picking a bitsize for them.
+ if minbits > preferredbits:
+ minbits = preferredbits
+ if maxbits < preferredbits:
+ maxbits = preferredbits
+ # now save a copy
+ self.min_bits = minbits
+ self.preferred_bits = preferredbits
+ self.max_bits = maxbits
+ # generate prime
+ pack = self.transport._get_modulus_pack()
+ if pack is None:
+ raise SSHException("Can't do server-side gex with no modulus pack")
+ self.transport._log(
+ DEBUG, # noqa
+ "Picking p ({} <= {} <= {} bits)".format(
+ minbits, preferredbits, maxbits
+ ),
+ )
+ self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits)
+ m = Message()
+ m.add_byte(c_MSG_KEXGSS_GROUP)
+ m.add_mpint(self.p)
+ m.add_mpint(self.g)
+ self.transport._send_message(m)
+ self.transport._expect_packet(MSG_KEXGSS_INIT)
def _parse_kexgss_group(self, m):
"""
@@ -189,7 +469,32 @@ class KexGSSGex:
:param `Message` m: The content of the SSH2_MSG_KEXGSS_GROUP message
"""
- pass
+ self.p = m.get_mpint()
+ self.g = m.get_mpint()
+ # reject if p's bit length < 1024 or > 8192
+ bitlen = util.bit_length(self.p)
+ if (bitlen < 1024) or (bitlen > 8192):
+ raise SSHException(
+ "Server-generated gex p (don't ask) is out of range "
+ "({} bits)".format(bitlen)
+ )
+ self.transport._log(
+ DEBUG, "Got server p ({} bits)".format(bitlen)
+ ) # noqa
+ self._generate_x()
+ # now compute e = g^x mod p
+ self.e = pow(self.g, self.x, self.p)
+ m = Message()
+ m.add_byte(c_MSG_KEXGSS_INIT)
+ m.add_string(self.kexgss.ssh_init_sec_context(target=self.gss_host))
+ m.add_mpint(self.e)
+ self.transport._send_message(m)
+ self.transport._expect_packet(
+ MSG_KEXGSS_HOSTKEY,
+ MSG_KEXGSS_CONTINUE,
+ MSG_KEXGSS_COMPLETE,
+ MSG_KEXGSS_ERROR,
+ )
def _parse_kexgss_gex_init(self, m):
"""
@@ -197,7 +502,61 @@ class KexGSSGex:
:param `Message` m: The content of the SSH2_MSG_KEXGSS_INIT message
"""
- pass
+ client_token = m.get_string()
+ self.e = m.get_mpint()
+ if (self.e < 1) or (self.e > self.p - 1):
+ raise SSHException('Client kex "e" is out of range')
+ self._generate_x()
+ self.f = pow(self.g, self.x, self.p)
+ K = pow(self.e, self.x, self.p)
+ self.transport.host_key = NullHostKey()
+ key = self.transport.host_key.__str__()
+ # okay, build up the hash H of
+ # (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa
+ hm = Message()
+ hm.add(
+ self.transport.remote_version,
+ self.transport.local_version,
+ self.transport.remote_kex_init,
+ self.transport.local_kex_init,
+ key,
+ )
+ hm.add_int(self.min_bits)
+ hm.add_int(self.preferred_bits)
+ hm.add_int(self.max_bits)
+ hm.add_mpint(self.p)
+ hm.add_mpint(self.g)
+ hm.add_mpint(self.e)
+ hm.add_mpint(self.f)
+ hm.add_mpint(K)
+ H = sha1(hm.asbytes()).digest()
+ self.transport._set_K_H(K, H)
+ srv_token = self.kexgss.ssh_accept_sec_context(
+ self.gss_host, client_token
+ )
+ m = Message()
+ if self.kexgss._gss_srv_ctxt_status:
+ mic_token = self.kexgss.ssh_get_mic(
+ self.transport.session_id, gss_kex=True
+ )
+ m.add_byte(c_MSG_KEXGSS_COMPLETE)
+ m.add_mpint(self.f)
+ m.add_string(mic_token)
+ if srv_token is not None:
+ m.add_boolean(True)
+ m.add_string(srv_token)
+ else:
+ m.add_boolean(False)
+ self.transport._send_message(m)
+ self.transport.gss_kex_used = True
+ self.transport._activate_outbound()
+ else:
+ m.add_byte(c_MSG_KEXGSS_CONTINUE)
+ m.add_string(srv_token)
+ self.transport._send_message(m)
+ self.transport._expect_packet(
+ MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR
+ )
def _parse_kexgss_hostkey(self, m):
"""
@@ -205,7 +564,12 @@ class KexGSSGex:
:param `Message` m: The content of the SSH2_MSG_KEXGSS_HOSTKEY message
"""
- pass
+ # client mode
+ host_key = m.get_string()
+ self.transport.host_key = host_key
+ sig = m.get_string()
+ self.transport._verify_key(host_key, sig)
+ self.transport._expect_packet(MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE)
def _parse_kexgss_continue(self, m):
"""
@@ -213,7 +577,21 @@ class KexGSSGex:
:param `Message` m: The content of the SSH2_MSG_KEXGSS_CONTINUE message
"""
- pass
+ if not self.transport.server_mode:
+ srv_token = m.get_string()
+ m = Message()
+ m.add_byte(c_MSG_KEXGSS_CONTINUE)
+ m.add_string(
+ self.kexgss.ssh_init_sec_context(
+ target=self.gss_host, recv_token=srv_token
+ )
+ )
+ self.transport.send_message(m)
+ self.transport._expect_packet(
+ MSG_KEXGSS_CONTINUE, MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR
+ )
+ else:
+ pass
def _parse_kexgss_complete(self, m):
"""
@@ -221,7 +599,49 @@ class KexGSSGex:
:param `Message` m: The content of the SSH2_MSG_KEXGSS_COMPLETE message
"""
- pass
+ if self.transport.host_key is None:
+ self.transport.host_key = NullHostKey()
+ self.f = m.get_mpint()
+ mic_token = m.get_string()
+ # This must be TRUE, if there is a GSS-API token in this message.
+ bool = m.get_boolean()
+ srv_token = None
+ if bool:
+ srv_token = m.get_string()
+ if (self.f < 1) or (self.f > self.p - 1):
+ raise SSHException('Server kex "f" is out of range')
+ K = pow(self.f, self.x, self.p)
+ # okay, build up the hash H of
+ # (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # noqa
+ hm = Message()
+ hm.add(
+ self.transport.local_version,
+ self.transport.remote_version,
+ self.transport.local_kex_init,
+ self.transport.remote_kex_init,
+ self.transport.host_key.__str__(),
+ )
+ if not self.old_style:
+ hm.add_int(self.min_bits)
+ hm.add_int(self.preferred_bits)
+ if not self.old_style:
+ hm.add_int(self.max_bits)
+ hm.add_mpint(self.p)
+ hm.add_mpint(self.g)
+ hm.add_mpint(self.e)
+ hm.add_mpint(self.f)
+ hm.add_mpint(K)
+ H = sha1(hm.asbytes()).digest()
+ self.transport._set_K_H(K, H)
+ if srv_token is not None:
+ self.kexgss.ssh_init_sec_context(
+ target=self.gss_host, recv_token=srv_token
+ )
+ self.kexgss.ssh_check_mic(mic_token, H)
+ else:
+ self.kexgss.ssh_check_mic(mic_token, H)
+ self.transport.gss_kex_used = True
+ self.transport._activate_outbound()
def _parse_kexgss_error(self, m):
"""
@@ -234,7 +654,19 @@ class KexGSSGex:
the error message and the language tag of the
message
"""
- pass
+ maj_status = m.get_int()
+ min_status = m.get_int()
+ err_msg = m.get_string()
+ m.get_string() # we don't care about the language (lang_tag)!
+ raise SSHException(
+ """GSS-API Error:
+Major Status: {}
+Minor Status: {}
+Error Message: {}
+""".format(
+ maj_status, min_status, err_msg
+ )
+ )
class NullHostKey:
@@ -245,7 +677,10 @@ class NullHostKey:
"""
def __init__(self):
- self.key = ''
+ self.key = ""
def __str__(self):
return self.key
+
+ def get_name(self):
+ return self.key
diff --git a/paramiko/message.py b/paramiko/message.py
index 7e6e2c5a..8c2b3bd0 100644
--- a/paramiko/message.py
+++ b/paramiko/message.py
@@ -1,8 +1,28 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Implementation of an SSH2 "message".
"""
+
import struct
from io import BytesIO
+
from paramiko import util
from paramiko.common import zero_byte, max_byte, one_byte
from paramiko.util import u
@@ -18,7 +38,8 @@ class Message:
exposed for people implementing custom extensions, or features that
paramiko doesn't support yet.
"""
- big_int = 4278190080
+
+ big_int = 0xFF000000
def __init__(self, content=None):
"""
@@ -40,27 +61,31 @@ class Message:
"""
Returns a string representation of this object, for debugging.
"""
- return 'paramiko.Message(' + repr(self.packet.getvalue()) + ')'
+ return "paramiko.Message(" + repr(self.packet.getvalue()) + ")"
+ # TODO 4.0: just merge into __bytes__ (everywhere)
def asbytes(self):
"""
Return the byte stream content of this Message, as a `bytes`.
"""
- pass
+ return self.packet.getvalue()
def rewind(self):
"""
Rewind the message to the beginning as if no items had been parsed
out of it yet.
"""
- pass
+ self.packet.seek(0)
def get_remainder(self):
"""
Return the `bytes` of this message that haven't already been parsed and
returned.
"""
- pass
+ position = self.packet.tell()
+ remainder = self.packet.read()
+ self.packet.seek(position)
+ return remainder
def get_so_far(self):
"""
@@ -68,7 +93,9 @@ class Message:
returned. The string passed into a message's constructor can be
regenerated by concatenating ``get_so_far`` and `get_remainder`.
"""
- pass
+ position = self.packet.tell()
+ self.rewind()
+ return self.packet.read(position)
def get_bytes(self, n):
"""
@@ -77,17 +104,29 @@ class Message:
string of ``n`` zero bytes if there weren't ``n`` bytes remaining in
the message.
"""
- pass
+ b = self.packet.read(n)
+ max_pad_size = 1 << 20 # Limit padding to 1 MB
+ if len(b) < n < max_pad_size:
+ return b + zero_byte * (n - len(b))
+ return b
def get_byte(self):
- "\n Return the next byte of the message, without decomposing it. This\n is equivalent to `get_bytes(1) <get_bytes>`.\n\n :return:\n the next (`bytes`) byte of the message, or ``b'\x00'`` if there\n aren't any bytes remaining.\n "
- pass
+ """
+ Return the next byte of the message, without decomposing it. This
+ is equivalent to `get_bytes(1) <get_bytes>`.
+
+ :return:
+ the next (`bytes`) byte of the message, or ``b'\000'`` if there
+ aren't any bytes remaining.
+ """
+ return self.get_bytes(1)
def get_boolean(self):
"""
Fetch a boolean from the stream.
"""
- pass
+ b = self.get_bytes(1)
+ return b != zero_byte
def get_adaptive_int(self):
"""
@@ -95,13 +134,17 @@ class Message:
:return: a 32-bit unsigned `int`.
"""
- pass
+ byte = self.get_bytes(1)
+ if byte == max_byte:
+ return util.inflate_long(self.get_binary())
+ byte += self.get_bytes(3)
+ return struct.unpack(">I", byte)[0]
def get_int(self):
"""
Fetch an int from the stream.
"""
- pass
+ return struct.unpack(">I", self.get_bytes(4))[0]
def get_int64(self):
"""
@@ -109,7 +152,7 @@ class Message:
:return: a 64-bit unsigned integer (`int`).
"""
- pass
+ return struct.unpack(">Q", self.get_bytes(8))[0]
def get_mpint(self):
"""
@@ -117,16 +160,20 @@ class Message:
:return: an arbitrary-length integer (`int`).
"""
- pass
+ return util.inflate_long(self.get_binary())
+ # TODO 4.0: depending on where this is used internally or downstream, force
+ # users to specify get_binary instead and delete this.
def get_string(self):
"""
Fetch a "string" from the stream. This will actually be a `bytes`
object, and may contain unprintable characters. (It's not unheard of
for a string to contain another byte-stream message.)
"""
- pass
+ return self.get_bytes(self.get_int())
+ # TODO 4.0: also consider having this take over the get_string name, and
+ # remove this name instead.
def get_text(self):
"""
Fetch a Unicode string from the stream.
@@ -134,13 +181,13 @@ class Message:
This currently operates by attempting to encode the next "string" as
``utf-8``.
"""
- pass
+ return u(self.get_string())
def get_binary(self):
"""
Alias for `get_string` (obtains a bytestring).
"""
- pass
+ return self.get_bytes(self.get_int())
def get_list(self):
"""
@@ -148,7 +195,7 @@ class Message:
These are trivially encoded as comma-separated values in a string.
"""
- pass
+ return self.get_text().split(",")
def add_bytes(self, b):
"""
@@ -156,7 +203,8 @@ class Message:
:param bytes b: bytes to add
"""
- pass
+ self.packet.write(b)
+ return self
def add_byte(self, b):
"""
@@ -164,7 +212,8 @@ class Message:
:param bytes b: byte to add
"""
- pass
+ self.packet.write(b)
+ return self
def add_boolean(self, b):
"""
@@ -172,7 +221,11 @@ class Message:
:param bool b: boolean value to add
"""
- pass
+ if b:
+ self.packet.write(one_byte)
+ else:
+ self.packet.write(zero_byte)
+ return self
def add_int(self, n):
"""
@@ -180,7 +233,8 @@ class Message:
:param int n: integer to add
"""
- pass
+ self.packet.write(struct.pack(">I", n))
+ return self
def add_adaptive_int(self, n):
"""
@@ -188,7 +242,12 @@ class Message:
:param int n: integer to add
"""
- pass
+ if n >= Message.big_int:
+ self.packet.write(max_byte)
+ self.add_string(util.deflate_long(n))
+ else:
+ self.packet.write(struct.pack(">I", n))
+ return self
def add_int64(self, n):
"""
@@ -196,7 +255,8 @@ class Message:
:param int n: long int to add
"""
- pass
+ self.packet.write(struct.pack(">Q", n))
+ return self
def add_mpint(self, z):
"""
@@ -205,17 +265,23 @@ class Message:
:param int z: long int to add
"""
- pass
+ self.add_string(util.deflate_long(z))
+ return self
+ # TODO: see the TODO for get_string/get_text/et al, this should change
+ # to match.
def add_string(self, s):
"""
Add a bytestring to the stream.
:param byte s: bytestring to add
"""
- pass
+ s = util.asbytes(s)
+ self.add_int(len(s))
+ self.packet.write(s)
+ return self
- def add_list(self, l):
+ def add_list(self, l): # noqa: E741
"""
Add a list of strings to the stream. They are encoded identically to
a single string of values separated by commas. (Yes, really, that's
@@ -223,8 +289,21 @@ class Message:
:param l: list of strings to add
"""
- pass
+ self.add_string(",".join(l))
+ return self
+
+ def _add(self, i):
+ if type(i) is bool:
+ return self.add_boolean(i)
+ elif isinstance(i, int):
+ return self.add_adaptive_int(i)
+ elif type(i) is list:
+ return self.add_list(i)
+ else:
+ return self.add_string(i)
+ # TODO: this would never have worked for unicode strings under Python 3,
+ # guessing nobody/nothing ever used it for that purpose?
def add(self, *seq):
"""
Add a sequence of items to the stream. The values are encoded based
@@ -235,4 +314,5 @@ class Message:
:param seq: the sequence of items
"""
- pass
+ for item in seq:
+ self._add(item)
diff --git a/paramiko/packet.py b/paramiko/packet.py
index 92f24b8c..1274a23c 100644
--- a/paramiko/packet.py
+++ b/paramiko/packet.py
@@ -1,6 +1,25 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Packet handling
"""
+
import errno
import os
import socket
@@ -8,27 +27,55 @@ import struct
import threading
import time
from hmac import HMAC
+
from paramiko import util
-from paramiko.common import linefeed_byte, cr_byte_value, MSG_NAMES, DEBUG, xffffffff, zero_byte, byte_ord
+from paramiko.common import (
+ linefeed_byte,
+ cr_byte_value,
+ MSG_NAMES,
+ DEBUG,
+ xffffffff,
+ zero_byte,
+ byte_ord,
+)
from paramiko.util import u
from paramiko.ssh_exception import SSHException, ProxyCommandFailure
from paramiko.message import Message
+def compute_hmac(key, message, digest_class):
+ return HMAC(key, message, digest_class).digest()
+
+
class NeedRekeyException(Exception):
"""
Exception indicating a rekey is needed.
"""
+
pass
+def first_arg(e):
+ arg = None
+ if type(e.args) is tuple and len(e.args) > 0:
+ arg = e.args[0]
+ return arg
+
+
class Packetizer:
"""
Implementation of the base SSH packet protocol.
"""
+
+ # Read the secsh RFCs before raising these values; if anything,
+ # they should probably be lower.
REKEY_PACKETS = pow(2, 29)
REKEY_BYTES = pow(2, 29)
+
+ # Allow receiving this many packets after a re-key request before
+ # terminating
REKEY_PACKETS_OVERFLOW_MAX = pow(2, 29)
+ # Allow receiving this many bytes after a re-key request before terminating
REKEY_BYTES_OVERFLOW_MAX = pow(2, 29)
def __init__(self, socket):
@@ -40,12 +87,16 @@ class Packetizer:
self.__init_count = 0
self.__remainder = bytes()
self._initial_kex_done = False
+
+ # used for noticing when to re-key:
self.__sent_bytes = 0
self.__sent_packets = 0
self.__received_bytes = 0
self.__received_packets = 0
self.__received_bytes_overflow = 0
self.__received_packets_overflow = 0
+
+ # current inbound/outbound ciphering:
self.__block_size_out = 8
self.__block_size_in = 8
self.__mac_size_out = 0
@@ -63,35 +114,116 @@ class Packetizer:
self.__sequence_number_in = 0
self.__etm_out = False
self.__etm_in = False
+
+ # lock around outbound writes (packet computation)
self.__write_lock = threading.RLock()
+
+ # keepalives:
self.__keepalive_interval = 0
self.__keepalive_last = time.time()
self.__keepalive_callback = None
+
self.__timer = None
self.__handshake_complete = False
self.__timer_expired = False
+ @property
+ def closed(self):
+ return self.__closed
+
+ def reset_seqno_out(self):
+ self.__sequence_number_out = 0
+
+ def reset_seqno_in(self):
+ self.__sequence_number_in = 0
+
def set_log(self, log):
"""
Set the Python log object to use for logging.
"""
- pass
+ self.__logger = log
- def set_outbound_cipher(self, block_engine, block_size, mac_engine,
- mac_size, mac_key, sdctr=False, etm=False):
+ def set_outbound_cipher(
+ self,
+ block_engine,
+ block_size,
+ mac_engine,
+ mac_size,
+ mac_key,
+ sdctr=False,
+ etm=False,
+ ):
"""
Switch outbound data cipher.
:param etm: Set encrypt-then-mac from OpenSSH
"""
- pass
+ self.__block_engine_out = block_engine
+ self.__sdctr_out = sdctr
+ self.__block_size_out = block_size
+ self.__mac_engine_out = mac_engine
+ self.__mac_size_out = mac_size
+ self.__mac_key_out = mac_key
+ self.__sent_bytes = 0
+ self.__sent_packets = 0
+ self.__etm_out = etm
+ # wait until the reset happens in both directions before clearing
+ # rekey flag
+ self.__init_count |= 1
+ if self.__init_count == 3:
+ self.__init_count = 0
+ self.__need_rekey = False
- def set_inbound_cipher(self, block_engine, block_size, mac_engine,
- mac_size, mac_key, etm=False):
+ def set_inbound_cipher(
+ self,
+ block_engine,
+ block_size,
+ mac_engine,
+ mac_size,
+ mac_key,
+ etm=False,
+ ):
"""
Switch inbound data cipher.
:param etm: Set encrypt-then-mac from OpenSSH
"""
- pass
+ self.__block_engine_in = block_engine
+ self.__block_size_in = block_size
+ self.__mac_engine_in = mac_engine
+ self.__mac_size_in = mac_size
+ self.__mac_key_in = mac_key
+ self.__received_bytes = 0
+ self.__received_packets = 0
+ self.__received_bytes_overflow = 0
+ self.__received_packets_overflow = 0
+ self.__etm_in = etm
+ # wait until the reset happens in both directions before clearing
+ # rekey flag
+ self.__init_count |= 2
+ if self.__init_count == 3:
+ self.__init_count = 0
+ self.__need_rekey = False
+
+ def set_outbound_compressor(self, compressor):
+ self.__compress_engine_out = compressor
+
+ def set_inbound_compressor(self, compressor):
+ self.__compress_engine_in = compressor
+
+ def close(self):
+ self.__closed = True
+ self.__socket.close()
+
+ def set_hexdump(self, hexdump):
+ self.__dump_packets = hexdump
+
+ def get_hexdump(self):
+ return self.__dump_packets
+
+ def get_mac_size_in(self):
+ return self.__mac_size_in
+
+ def get_mac_size_out(self):
+ return self.__mac_size_out
def need_rekey(self):
"""
@@ -99,7 +231,7 @@ class Packetizer:
will be triggered during a packet read or write, so it should be
checked after every read or write, or at least after every few.
"""
- pass
+ return self.__need_rekey
def set_keepalive(self, interval, callback):
"""
@@ -107,7 +239,12 @@ class Packetizer:
no data read from or written to the socket, the callback will be
executed and the timer will be reset.
"""
- pass
+ self.__keepalive_interval = interval
+ self.__keepalive_callback = callback
+ self.__keepalive_last = time.time()
+
+ def read_timer(self):
+ self.__timer_expired = True
def start_handshake(self, timeout):
"""
@@ -117,7 +254,9 @@ class Packetizer:
:param float timeout: amount of seconds to wait before timing out
"""
- pass
+ if not self.__timer:
+ self.__timer = threading.Timer(float(timeout), self.read_timer)
+ self.__timer.start()
def handshake_timed_out(self):
"""
@@ -129,13 +268,20 @@ class Packetizer:
:return: handshake time out status, as a `bool`
"""
- pass
+ if not self.__timer:
+ return False
+ if self.__handshake_complete:
+ return False
+ return self.__timer_expired
def complete_handshake(self):
"""
Tells `Packetizer` that the handshake has completed.
"""
- pass
+ if self.__timer:
+ self.__timer.cancel()
+ self.__timer_expired = False
+ self.__handshake_complete = True
def read_all(self, n, check_rekey=False):
"""
@@ -148,20 +294,163 @@ class Packetizer:
``EOFError`` -- if the socket was closed before all the bytes could
be read
"""
- pass
+ out = bytes()
+ # handle over-reading from reading the banner line
+ if len(self.__remainder) > 0:
+ out = self.__remainder[:n]
+ self.__remainder = self.__remainder[n:]
+ n -= len(out)
+ while n > 0:
+ got_timeout = False
+ if self.handshake_timed_out():
+ raise EOFError()
+ try:
+ x = self.__socket.recv(n)
+ if len(x) == 0:
+ raise EOFError()
+ out += x
+ n -= len(x)
+ except socket.timeout:
+ got_timeout = True
+ except socket.error as e:
+ # on Linux, sometimes instead of socket.timeout, we get
+ # EAGAIN. this is a bug in recent (> 2.6.9) kernels but
+ # we need to work around it.
+ arg = first_arg(e)
+ if arg == errno.EAGAIN:
+ got_timeout = True
+ elif self.__closed:
+ raise EOFError()
+ else:
+ raise
+ if got_timeout:
+ if self.__closed:
+ raise EOFError()
+ if check_rekey and (len(out) == 0) and self.__need_rekey:
+ raise NeedRekeyException()
+ self._check_keepalive()
+ return out
+
+ def write_all(self, out):
+ self.__keepalive_last = time.time()
+ iteration_with_zero_as_return_value = 0
+ while len(out) > 0:
+ retry_write = False
+ try:
+ n = self.__socket.send(out)
+ except socket.timeout:
+ retry_write = True
+ except socket.error as e:
+ arg = first_arg(e)
+ if arg == errno.EAGAIN:
+ retry_write = True
+ else:
+ n = -1
+ except ProxyCommandFailure:
+ raise # so it doesn't get swallowed by the below catchall
+ except Exception:
+ # could be: (32, 'Broken pipe')
+ n = -1
+ if retry_write:
+ n = 0
+ if self.__closed:
+ n = -1
+ else:
+ if n == 0 and iteration_with_zero_as_return_value > 10:
+ # We shouldn't retry the write, but we didn't
+ # manage to send anything over the socket. This might be an
+ # indication that we have lost contact with the remote
+ # side, but are yet to receive an EOFError or other socket
+ # errors. Let's give it some iterations to try to catch up.
+ n = -1
+ iteration_with_zero_as_return_value += 1
+ if n < 0:
+ raise EOFError()
+ if n == len(out):
+ break
+ out = out[n:]
+ return
def readline(self, timeout):
"""
Read a line from the socket. We assume no data is pending after the
line, so it's okay to attempt large reads.
"""
- pass
+ buf = self.__remainder
+ while linefeed_byte not in buf:
+ buf += self._read_timeout(timeout)
+ n = buf.index(linefeed_byte)
+ self.__remainder = buf[n + 1 :]
+ buf = buf[:n]
+ if (len(buf) > 0) and (buf[-1] == cr_byte_value):
+ buf = buf[:-1]
+ return u(buf)
def send_message(self, data):
"""
Write a block of data using the current cipher, as an SSH block.
"""
- pass
+ # serialize, then compress, encrypt, MAC and send the message
+ data = data.asbytes()
+ cmd = byte_ord(data[0])
+ if cmd in MSG_NAMES:
+ cmd_name = MSG_NAMES[cmd]
+ else:
+ cmd_name = "${:x}".format(cmd)
+ orig_len = len(data)
+ self.__write_lock.acquire()
+ try:
+ if self.__compress_engine_out is not None:
+ data = self.__compress_engine_out(data)
+ packet = self._build_packet(data)
+ if self.__dump_packets:
+ self._log(
+ DEBUG,
+ "Write packet <{}>, length {}".format(cmd_name, orig_len),
+ )
+ self._log(DEBUG, util.format_binary(packet, "OUT: "))
+ if self.__block_engine_out is not None:
+ if self.__etm_out:
+ # packet length is not encrypted in EtM
+ out = packet[0:4] + self.__block_engine_out.update(
+ packet[4:]
+ )
+ else:
+ out = self.__block_engine_out.update(packet)
+ else:
+ out = packet
+ # + mac
+ if self.__block_engine_out is not None:
+ packed = struct.pack(">I", self.__sequence_number_out)
+ payload = packed + (out if self.__etm_out else packet)
+ out += compute_hmac(
+ self.__mac_key_out, payload, self.__mac_engine_out
+ )[: self.__mac_size_out]
+ next_seq = (self.__sequence_number_out + 1) & xffffffff
+ if next_seq == 0 and not self._initial_kex_done:
+ raise SSHException(
+ "Sequence number rolled over during initial kex!"
+ )
+ self.__sequence_number_out = next_seq
+ self.write_all(out)
+
+ self.__sent_bytes += len(out)
+ self.__sent_packets += 1
+ sent_too_much = (
+ self.__sent_packets >= self.REKEY_PACKETS
+ or self.__sent_bytes >= self.REKEY_BYTES
+ )
+ if sent_too_much and not self.__need_rekey:
+ # only ask once for rekeying
+ msg = "Rekeying (hit {} packets, {} bytes sent)"
+ self._log(
+ DEBUG, msg.format(self.__sent_packets, self.__sent_bytes)
+ )
+ self.__received_bytes_overflow = 0
+ self.__received_packets_overflow = 0
+ self._trigger_rekey()
+ finally:
+ self.__write_lock.release()
def read_message(self):
"""
@@ -171,4 +460,190 @@ class Packetizer:
:raises: `.SSHException` -- if the packet is mangled
:raises: `.NeedRekeyException` -- if the transport should rekey
"""
- pass
+ header = self.read_all(self.__block_size_in, check_rekey=True)
+ if self.__etm_in:
+ packet_size = struct.unpack(">I", header[:4])[0]
+ remaining = packet_size - self.__block_size_in + 4
+ packet = header[4:] + self.read_all(remaining, check_rekey=False)
+ mac = self.read_all(self.__mac_size_in, check_rekey=False)
+ mac_payload = (
+ struct.pack(">II", self.__sequence_number_in, packet_size)
+ + packet
+ )
+ my_mac = compute_hmac(
+ self.__mac_key_in, mac_payload, self.__mac_engine_in
+ )[: self.__mac_size_in]
+ if not util.constant_time_bytes_eq(my_mac, mac):
+ raise SSHException("Mismatched MAC")
+ header = packet
+
+ if self.__block_engine_in is not None:
+ header = self.__block_engine_in.update(header)
+ if self.__dump_packets:
+ self._log(DEBUG, util.format_binary(header, "IN: "))
+
+ # When ETM is in play, we've already read the packet size & decrypted
+ # everything, so just set the packet back to the header we obtained.
+ if self.__etm_in:
+ packet = header
+ # Otherwise, use the older non-ETM logic
+ else:
+ packet_size = struct.unpack(">I", header[:4])[0]
+
+ # leftover contains decrypted bytes from the first block (after the
+ # length field)
+ leftover = header[4:]
+ if (packet_size - len(leftover)) % self.__block_size_in != 0:
+ raise SSHException("Invalid packet blocking")
+ buf = self.read_all(
+ packet_size + self.__mac_size_in - len(leftover)
+ )
+ packet = buf[: packet_size - len(leftover)]
+ post_packet = buf[packet_size - len(leftover) :]
+
+ if self.__block_engine_in is not None:
+ packet = self.__block_engine_in.update(packet)
+ packet = leftover + packet
+
+ if self.__dump_packets:
+ self._log(DEBUG, util.format_binary(packet, "IN: "))
+
+ if self.__mac_size_in > 0 and not self.__etm_in:
+ mac = post_packet[: self.__mac_size_in]
+ mac_payload = (
+ struct.pack(">II", self.__sequence_number_in, packet_size)
+ + packet
+ )
+ my_mac = compute_hmac(
+ self.__mac_key_in, mac_payload, self.__mac_engine_in
+ )[: self.__mac_size_in]
+ if not util.constant_time_bytes_eq(my_mac, mac):
+ raise SSHException("Mismatched MAC")
+ padding = byte_ord(packet[0])
+ payload = packet[1 : packet_size - padding]
+
+ if self.__dump_packets:
+ self._log(
+ DEBUG,
+ "Got payload ({} bytes, {} padding)".format(
+ packet_size, padding
+ ),
+ )
+
+ if self.__compress_engine_in is not None:
+ payload = self.__compress_engine_in(payload)
+
+ msg = Message(payload[1:])
+ msg.seqno = self.__sequence_number_in
+ next_seq = (self.__sequence_number_in + 1) & xffffffff
+ if next_seq == 0 and not self._initial_kex_done:
+ raise SSHException(
+ "Sequence number rolled over during initial kex!"
+ )
+ self.__sequence_number_in = next_seq
+
+ # check for rekey
+ raw_packet_size = packet_size + self.__mac_size_in + 4
+ self.__received_bytes += raw_packet_size
+ self.__received_packets += 1
+ if self.__need_rekey:
+ # we've asked to rekey -- give them some packets to comply before
+ # dropping the connection
+ self.__received_bytes_overflow += raw_packet_size
+ self.__received_packets_overflow += 1
+ if (
+ self.__received_packets_overflow
+ >= self.REKEY_PACKETS_OVERFLOW_MAX
+ ) or (
+ self.__received_bytes_overflow >= self.REKEY_BYTES_OVERFLOW_MAX
+ ):
+ raise SSHException(
+ "Remote transport is ignoring rekey requests"
+ )
+ elif (self.__received_packets >= self.REKEY_PACKETS) or (
+ self.__received_bytes >= self.REKEY_BYTES
+ ):
+ # only ask once for rekeying
+ err = "Rekeying (hit {} packets, {} bytes received)"
+ self._log(
+ DEBUG,
+ err.format(self.__received_packets, self.__received_bytes),
+ )
+ self.__received_bytes_overflow = 0
+ self.__received_packets_overflow = 0
+ self._trigger_rekey()
+
+ cmd = byte_ord(payload[0])
+ if cmd in MSG_NAMES:
+ cmd_name = MSG_NAMES[cmd]
+ else:
+ cmd_name = "${:x}".format(cmd)
+ if self.__dump_packets:
+ self._log(
+ DEBUG,
+ "Read packet <{}>, length {}".format(cmd_name, len(payload)),
+ )
+ return cmd, msg
+
+ # ...protected...
+
+ def _log(self, level, msg):
+ if self.__logger is None:
+ return
+ if issubclass(type(msg), list):
+ for m in msg:
+ self.__logger.log(level, m)
+ else:
+ self.__logger.log(level, msg)
+
+ def _check_keepalive(self):
+ if (
+ not self.__keepalive_interval
+ or not self.__block_engine_out
+ or self.__need_rekey
+ ):
+ # wait till we're encrypting, and not in the middle of rekeying
+ return
+ now = time.time()
+ if now > self.__keepalive_last + self.__keepalive_interval:
+ self.__keepalive_callback()
+ self.__keepalive_last = now
+
+ def _read_timeout(self, timeout):
+ start = time.time()
+ while True:
+ try:
+ x = self.__socket.recv(128)
+ if len(x) == 0:
+ raise EOFError()
+ break
+ except socket.timeout:
+ pass
+ if self.__closed:
+ raise EOFError()
+ now = time.time()
+ if now - start >= timeout:
+ raise socket.timeout()
+ return x
+
+ def _build_packet(self, payload):
+ # pad up at least 4 bytes, to nearest block-size (usually 8)
+ bsize = self.__block_size_out
+ # do not include payload length in computations for padding in EtM mode
+ # (payload length won't be encrypted)
+ addlen = 4 if self.__etm_out else 8
+ padding = 3 + bsize - ((len(payload) + addlen) % bsize)
+ packet = struct.pack(">IB", len(payload) + padding + 1, padding)
+ packet += payload
+ if self.__sdctr_out or self.__block_engine_out is None:
+ # cute trick I caught OpenSSH doing: if we're not encrypting, or
+ # are in SDCTR mode (RFC 4344),
+ # don't waste random bytes on the padding
+ packet += zero_byte * padding
+ else:
+ packet += os.urandom(padding)
+ return packet
+
+ def _trigger_rekey(self):
+ # outside code should check for this flag
+ self.__need_rekey = True
diff --git a/paramiko/pipe.py b/paramiko/pipe.py
index 0b740739..65944fad 100644
--- a/paramiko/pipe.py
+++ b/paramiko/pipe.py
@@ -1,3 +1,21 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Abstraction of a one-way pipe where the read end can be used in
`select.select`. Normally this is trivial, but Windows makes it nearly
@@ -6,19 +24,52 @@ impossible.
The pipe acts like an Event, which can be set or cleared. When set, the pipe
will trigger as readable in `select <select.select>`.
"""
+
import sys
import os
import socket
-class PosixPipe:
+def make_pipe():
+ if sys.platform[:3] != "win":
+ p = PosixPipe()
+ else:
+ p = WindowsPipe()
+ return p
+
+class PosixPipe:
def __init__(self):
self._rfd, self._wfd = os.pipe()
self._set = False
self._forever = False
self._closed = False
+ def close(self):
+ os.close(self._rfd)
+ os.close(self._wfd)
+ # used for unit tests:
+ self._closed = True
+
+ def fileno(self):
+ return self._rfd
+
+ def clear(self):
+ if not self._set or self._forever:
+ return
+ os.read(self._rfd, 1)
+ self._set = False
+
+ def set(self):
+ if self._set or self._closed:
+ return
+ self._set = True
+ os.write(self._wfd, b"*")
+
+ def set_forever(self):
+ self._forever = True
+ self.set()
+
class WindowsPipe:
"""
@@ -28,24 +79,61 @@ class WindowsPipe:
def __init__(self):
serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- serv.bind(('127.0.0.1', 0))
+ serv.bind(("127.0.0.1", 0))
serv.listen(1)
+
+ # need to save sockets in _rsock/_wsock so they don't get closed
self._rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self._rsock.connect(('127.0.0.1', serv.getsockname()[1]))
+ self._rsock.connect(("127.0.0.1", serv.getsockname()[1]))
+
self._wsock, addr = serv.accept()
serv.close()
self._set = False
self._forever = False
self._closed = False
+ def close(self):
+ self._rsock.close()
+ self._wsock.close()
+ # used for unit tests:
+ self._closed = True
-class OrPipe:
+ def fileno(self):
+ return self._rsock.fileno()
+
+ def clear(self):
+ if not self._set or self._forever:
+ return
+ self._rsock.recv(1)
+ self._set = False
+
+ def set(self):
+ if self._set or self._closed:
+ return
+ self._set = True
+ self._wsock.send(b"*")
+
+ def set_forever(self):
+ self._forever = True
+ self.set()
+
+class OrPipe:
def __init__(self, pipe):
self._set = False
self._partner = None
self._pipe = pipe
+ def set(self):
+ self._set = True
+ if not self._partner._set:
+ self._pipe.set()
+
+ def clear(self):
+ self._set = False
+ if not self._partner._set:
+ self._pipe.clear()
+
def make_or_pipe(pipe):
"""
@@ -53,4 +141,8 @@ def make_or_pipe(pipe):
affect the real pipe. if either returned pipe is set, the wrapped pipe
is set. when both are cleared, the wrapped pipe is cleared.
"""
- pass
+ p1 = OrPipe(pipe)
+ p2 = OrPipe(pipe)
+ p1._partner = p2
+ p2._partner = p1
+ return p1, p2
diff --git a/paramiko/pkey.py b/paramiko/pkey.py
index 69923124..f0b2d6d4 100644
--- a/paramiko/pkey.py
+++ b/paramiko/pkey.py
@@ -1,6 +1,25 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Common API for all public keys.
"""
+
import base64
from base64 import encodebytes, decodebytes
from binascii import unhexlify
@@ -9,21 +28,51 @@ from pathlib import Path
from hashlib import md5, sha256
import re
import struct
+
import bcrypt
+
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.ciphers import algorithms, modes, Cipher
from cryptography.hazmat.primitives import asymmetric
+
from paramiko import util
from paramiko.util import u, b
from paramiko.common import o600
from paramiko.ssh_exception import SSHException, PasswordRequiredException
from paramiko.message import Message
+
+
+# TripleDES is moving from `cryptography.hazmat.primitives.ciphers.algorithms`
+# in cryptography>=43.0.0 to `cryptography.hazmat.decrepit.ciphers.algorithms`
+# It will be removed from `cryptography.hazmat.primitives.ciphers.algorithms`
+# in cryptography==48.0.0.
+#
+# Source References:
+# - https://github.com/pyca/cryptography/commit/722a6393e61b3ac
+# - https://github.com/pyca/cryptography/pull/11407/files
try:
from cryptography.hazmat.decrepit.ciphers.algorithms import TripleDES
except ImportError:
from cryptography.hazmat.primitives.ciphers.algorithms import TripleDES
-OPENSSH_AUTH_MAGIC = b'openssh-key-v1\x00'
+
+
+OPENSSH_AUTH_MAGIC = b"openssh-key-v1\x00"
+
+
+def _unpad_openssh(data):
+ # At the moment, this is only used for unpadding private keys on disk. This
+ # really ought to be made constant time (possibly by upstreaming this logic
+ # into pyca/cryptography).
+ padding_length = data[-1]
+ if 0x20 <= padding_length < 0x7F:
+ return data # no padding, last byte part of comment (printable ascii)
+ if padding_length > 15:
+ raise SSHException("Invalid key")
+ for i in range(padding_length):
+ if data[i - padding_length] != i + 1:
+ raise SSHException("Invalid key")
+ return data[:-padding_length]
class UnknownKeyType(Exception):
@@ -36,9 +85,7 @@ class UnknownKeyType(Exception):
self.key_bytes = key_bytes
def __str__(self):
- return (
- f'UnknownKeyType(type={self.key_type!r}, bytes=<{len(self.key_bytes)}>)'
- )
+ return f"UnknownKeyType(type={self.key_type!r}, bytes=<{len(self.key_bytes)}>)" # noqa
class PKey:
@@ -48,16 +95,34 @@ class PKey:
Also includes some "meta" level convenience constructors such as
`.from_type_string`.
"""
- _CIPHER_TABLE = {'AES-128-CBC': {'cipher': algorithms.AES, 'keysize':
- 16, 'blocksize': 16, 'mode': modes.CBC}, 'AES-256-CBC': {'cipher':
- algorithms.AES, 'keysize': 32, 'blocksize': 16, 'mode': modes.CBC},
- 'DES-EDE3-CBC': {'cipher': TripleDES, 'keysize': 24, 'blocksize': 8,
- 'mode': modes.CBC}}
+
+ # known encryption types for private key files:
+ _CIPHER_TABLE = {
+ "AES-128-CBC": {
+ "cipher": algorithms.AES,
+ "keysize": 16,
+ "blocksize": 16,
+ "mode": modes.CBC,
+ },
+ "AES-256-CBC": {
+ "cipher": algorithms.AES,
+ "keysize": 32,
+ "blocksize": 16,
+ "mode": modes.CBC,
+ },
+ "DES-EDE3-CBC": {
+ "cipher": TripleDES,
+ "keysize": 24,
+ "blocksize": 8,
+ "mode": modes.CBC,
+ },
+ }
_PRIVATE_KEY_FORMAT_ORIGINAL = 1
_PRIVATE_KEY_FORMAT_OPENSSH = 2
BEGIN_TAG = re.compile(
- '^-{5}BEGIN (RSA|DSA|EC|OPENSSH) PRIVATE KEY-{5}\\s*$')
- END_TAG = re.compile('^-{5}END (RSA|DSA|EC|OPENSSH) PRIVATE KEY-{5}\\s*$')
+ r"^-{5}BEGIN (RSA|DSA|EC|OPENSSH) PRIVATE KEY-{5}\s*$"
+ )
+ END_TAG = re.compile(r"^-{5}END (RSA|DSA|EC|OPENSSH) PRIVATE KEY-{5}\s*$")
@staticmethod
def from_path(path, passphrase=None):
@@ -74,7 +139,64 @@ class PKey:
.. versionadded:: 3.2
"""
- pass
+ # TODO: make sure sphinx is reading Path right in param list...
+
+ # Lazy import to avoid circular import issues
+ from paramiko import DSSKey, RSAKey, Ed25519Key, ECDSAKey
+
+ # Normalize to string, as cert suffix isn't quite an extension, so
+ # pathlib isn't useful for this.
+ path = str(path)
+
+ # Sort out cert vs key, i.e. it is 'legal' to hand this kind of API
+ # /either/ the key /or/ the cert, when there is a key/cert pair.
+ cert_suffix = "-cert.pub"
+ if str(path).endswith(cert_suffix):
+ key_path = path[: -len(cert_suffix)]
+ cert_path = path
+ else:
+ key_path = path
+ cert_path = path + cert_suffix
+
+ key_path = Path(key_path).expanduser()
+ cert_path = Path(cert_path).expanduser()
+
+ data = key_path.read_bytes()
+ # Like OpenSSH, try modern/OpenSSH-specific key load first
+ try:
+ loaded = serialization.load_ssh_private_key(
+ data=data, password=passphrase
+ )
+ # Then fall back to assuming legacy PEM type
+ except ValueError:
+ loaded = serialization.load_pem_private_key(
+ data=data, password=passphrase
+ )
+ # TODO Python 3.10: match statement? (NOTE: we cannot use a dict
+ # because the results from the loader are literal backend, eg openssl,
+ # private classes, so isinstance tests work but exact 'x class is y'
+ # tests will not work)
+ # TODO: leverage already-parsed/math'd obj to avoid duplicate cpu
+ # cycles? seemingly requires most of our key subclasses to be rewritten
+ # to be cryptography-object-forward. this is still likely faster than
+ # the old SSHClient code that just tried instantiating every class!
+ key_class = None
+ if isinstance(loaded, asymmetric.dsa.DSAPrivateKey):
+ key_class = DSSKey
+ elif isinstance(loaded, asymmetric.rsa.RSAPrivateKey):
+ key_class = RSAKey
+ elif isinstance(loaded, asymmetric.ed25519.Ed25519PrivateKey):
+ key_class = Ed25519Key
+ elif isinstance(loaded, asymmetric.ec.EllipticCurvePrivateKey):
+ key_class = ECDSAKey
+ else:
+ raise UnknownKeyType(key_bytes=data, key_type=loaded.__class__)
+ with key_path.open() as fd:
+ key = key_class.from_private_key(fd, password=passphrase)
+ if cert_path.exists():
+ # load_certificate can take Message, path-str, or value-str
+ key.load_certificate(str(cert_path))
+ return key
@staticmethod
def from_type_string(key_type, key_bytes):
@@ -98,7 +220,13 @@ class PKey:
.. versionadded:: 3.2
"""
- pass
+ from paramiko import key_classes
+
+ for key_class in key_classes:
+ if key_type in key_class.identifiers():
+ # TODO: needs to passthru things like passphrase
+ return key_class(data=key_bytes)
+ raise UnknownKeyType(key_type=key_type, key_bytes=key_bytes)
@classmethod
def identifiers(cls):
@@ -109,8 +237,14 @@ class PKey:
implementation suffices; see `.ECDSAKey` for one example of an
override.
"""
- pass
+ return [cls.name]
+ # TODO 4.0: make this and subclasses consistent, some of our own
+ # classmethods even assume kwargs we don't define!
+ # TODO 4.0: prob also raise NotImplementedError instead of pass'ing; the
+ # contract is pretty obviously that you need to handle msg/data/filename
+ # appropriately. (If 'pass' is a concession to testing, see about doing the
+ # work to fix the tests instead)
def __init__(self, msg=None, data=None):
"""
Create a new instance of this public key type. If ``msg`` is given,
@@ -129,21 +263,27 @@ class PKey:
"""
pass
+ # TODO: arguably this might want to be __str__ instead? ehh
+ # TODO: ditto the interplay between showing class name (currently we just
+ # say PKey writ large) and algorithm (usually == class name, but not
+ # always, also sometimes shows certificate-ness)
+ # TODO: if we do change it, we also want to tweak eg AgentKey, as it
+ # currently displays agent-ness with a suffix
def __repr__(self):
- comment = ''
- if hasattr(self, 'comment') and self.comment:
- comment = f', comment={self.comment!r}'
- return (
- f'PKey(alg={self.algorithm_name}, bits={self.get_bits()}, fp={self.fingerprint}{comment})'
- )
+ comment = ""
+ # Works for AgentKey, may work for others?
+ if hasattr(self, "comment") and self.comment:
+ comment = f", comment={self.comment!r}"
+ return f"PKey(alg={self.algorithm_name}, bits={self.get_bits()}, fp={self.fingerprint}{comment})" # noqa
+ # TODO 4.0: just merge into __bytes__ (everywhere)
def asbytes(self):
"""
Return a string of an SSH `.Message` made up of the public part(s) of
this key. This string is suitable for passing to `__init__` to
re-create the key object later.
"""
- pass
+ return bytes()
def __bytes__(self):
return self.asbytes()
@@ -154,6 +294,10 @@ class PKey:
def __hash__(self):
return hash(self._fields)
+ @property
+ def _fields(self):
+ raise NotImplementedError
+
def get_name(self):
"""
Return the name of this private key implementation.
@@ -162,7 +306,7 @@ class PKey:
name of this private key type, in SSH terminology, as a `str` (for
example, ``"ssh-rsa"``).
"""
- pass
+ return ""
@property
def algorithm_name(self):
@@ -172,7 +316,17 @@ class PKey:
Similar to `get_name`, but aimed at pure algorithm name instead of SSH
protocol field value.
"""
- pass
+ # Nuke the leading 'ssh-'
+ # TODO in Python 3.9: use .removeprefix()
+ name = self.get_name().replace("ssh-", "")
+ # Trim any cert suffix (but leave the -cert, as OpenSSH does)
+ cert_tail = "-cert-v01@openssh.com"
+ if cert_tail in name:
+ name = name.replace(cert_tail, "-cert")
+ # Nuke any eg ECDSA suffix, OpenSSH does basically this too.
+ else:
+ name = name.split("-")[0]
+ return name.upper()
def get_bits(self):
"""
@@ -181,14 +335,16 @@ class PKey:
:return: bits in the key (as an `int`)
"""
- pass
+ # TODO 4.0: raise NotImplementedError, 0 is unlikely to ever be
+ # _correct_ and nothing in the critical path seems to use this.
+ return 0
def can_sign(self):
"""
Return ``True`` if this key has the private part necessary for signing
data.
"""
- pass
+ return False
def get_fingerprint(self):
"""
@@ -199,7 +355,7 @@ class PKey:
a 16-byte `string <str>` (binary) of the MD5 fingerprint, in SSH
format.
"""
- pass
+ return md5(self.asbytes()).digest()
@property
def fingerprint(self):
@@ -210,7 +366,11 @@ class PKey:
.. versionadded:: 3.2
"""
- pass
+ hashy = sha256(bytes(self))
+ hash_name = hashy.name.upper()
+ b64ed = encodebytes(hashy.digest())
+ cleaned = u(b64ed).strip().rstrip("=") # yes, OpenSSH does this too!
+ return f"{hash_name}:{cleaned}"
def get_base64(self):
"""
@@ -220,7 +380,7 @@ class PKey:
:return: a base64 `string <str>` containing the public part of the key.
"""
- pass
+ return u(encodebytes(self.asbytes())).replace("\n", "")
def sign_ssh_data(self, data, algorithm=None):
"""
@@ -237,7 +397,7 @@ class PKey:
.. versionchanged:: 2.9
Added the ``algorithm`` kwarg.
"""
- pass
+ return bytes()
def verify_ssh_sig(self, data, msg):
"""
@@ -249,7 +409,7 @@ class PKey:
:return:
``True`` if the signature verifies correctly; ``False`` otherwise.
"""
- pass
+ return False
@classmethod
def from_private_key_file(cls, filename, password=None):
@@ -272,7 +432,8 @@ class PKey:
encrypted, and ``password`` is ``None``
:raises: `.SSHException` -- if the key file is invalid
"""
- pass
+ key = cls(filename=filename, password=password)
+ return key
@classmethod
def from_private_key(cls, file_obj, password=None):
@@ -292,7 +453,8 @@ class PKey:
if the private key file is encrypted, and ``password`` is ``None``
:raises: `.SSHException` -- if the key file is invalid
"""
- pass
+ key = cls(file_obj=file_obj, password=password)
+ return key
def write_private_key_file(self, filename, password=None):
"""
@@ -306,7 +468,7 @@ class PKey:
:raises: ``IOError`` -- if there was an error writing the file
:raises: `.SSHException` -- if the key is invalid
"""
- pass
+ raise Exception("Not implemented in PKey")
def write_private_key(self, file_obj, password=None):
"""
@@ -319,7 +481,8 @@ class PKey:
:raises: ``IOError`` -- if there was an error writing to the file
:raises: `.SSHException` -- if the key is invalid
"""
- pass
+ # TODO 4.0: NotImplementedError (plus everywhere else in here)
+ raise Exception("Not implemented in PKey")
def _read_private_key_file(self, tag, filename, password=None):
"""
@@ -342,7 +505,97 @@ class PKey:
encrypted, and ``password`` is ``None``.
:raises: `.SSHException` -- if the key file is invalid.
"""
- pass
+ with open(filename, "r") as f:
+ data = self._read_private_key(tag, f, password)
+ return data
+
+ def _read_private_key(self, tag, f, password=None):
+ lines = f.readlines()
+ if not lines:
+ raise SSHException("no lines in {} private key file".format(tag))
+
+ # find the BEGIN tag
+ start = 0
+ m = self.BEGIN_TAG.match(lines[start])
+ line_range = len(lines) - 1
+ while start < line_range and not m:
+ start += 1
+ m = self.BEGIN_TAG.match(lines[start])
+ start += 1
+ keytype = m.group(1) if m else None
+ if start >= len(lines) or keytype is None:
+ raise SSHException("not a valid {} private key file".format(tag))
+
+ # find the END tag
+ end = start
+ m = self.END_TAG.match(lines[end])
+ while end < line_range and not m:
+ end += 1
+ m = self.END_TAG.match(lines[end])
+
+ if keytype == tag:
+ data = self._read_private_key_pem(lines, end, password)
+ pkformat = self._PRIVATE_KEY_FORMAT_ORIGINAL
+ elif keytype == "OPENSSH":
+ data = self._read_private_key_openssh(lines[start:end], password)
+ pkformat = self._PRIVATE_KEY_FORMAT_OPENSSH
+ else:
+ raise SSHException(
+ "encountered {} key, expected {} key".format(keytype, tag)
+ )
+
+ return pkformat, data
+
+ def _got_bad_key_format_id(self, id_):
+ err = "{}._read_private_key() spat out an unknown key format id '{}'"
+ raise SSHException(err.format(self.__class__.__name__, id_))
+
+ def _read_private_key_pem(self, lines, end, password):
+ start = 0
+ # parse any headers first
+ headers = {}
+ start += 1
+ while start < len(lines):
+ line = lines[start].split(": ")
+ if len(line) == 1:
+ break
+ headers[line[0].lower()] = line[1].strip()
+ start += 1
+ # if we trudged to the end of the file, just try to cope.
+ try:
+ data = decodebytes(b("".join(lines[start:end])))
+ except base64.binascii.Error as e:
+ raise SSHException("base64 decoding error: {}".format(e))
+ if "proc-type" not in headers:
+ # unencrypted: done
+ return data
+ # encrypted keyfile: will need a password
+ proc_type = headers["proc-type"]
+ if proc_type != "4,ENCRYPTED":
+ raise SSHException(
+ 'Unknown private key structure "{}"'.format(proc_type)
+ )
+ try:
+ encryption_type, saltstr = headers["dek-info"].split(",")
+ except:
+ raise SSHException("Can't parse DEK-info in private key file")
+ if encryption_type not in self._CIPHER_TABLE:
+ raise SSHException(
+ 'Unknown private key cipher "{}"'.format(encryption_type)
+ )
+ # if no password was passed in,
+ # raise an exception pointing out that we need one
+ if password is None:
+ raise PasswordRequiredException("Private key file is encrypted")
+ cipher = self._CIPHER_TABLE[encryption_type]["cipher"]
+ keysize = self._CIPHER_TABLE[encryption_type]["keysize"]
+ mode = self._CIPHER_TABLE[encryption_type]["mode"]
+ salt = unhexlify(b(saltstr))
+ key = util.generate_key_bytes(md5, salt, password, keysize)
+ decryptor = Cipher(
+ cipher(key), mode(salt), backend=default_backend()
+ ).decryptor()
+ return decryptor.update(data) + decryptor.finalize()
def _read_private_key_openssh(self, lines, password):
"""
@@ -351,7 +604,84 @@ class PKey:
Reference:
https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key
"""
- pass
+ try:
+ data = decodebytes(b("".join(lines)))
+ except base64.binascii.Error as e:
+ raise SSHException("base64 decoding error: {}".format(e))
+
+ # read data struct
+ auth_magic = data[:15]
+ if auth_magic != OPENSSH_AUTH_MAGIC:
+ raise SSHException("unexpected OpenSSH key header encountered")
+
+ cstruct = self._uint32_cstruct_unpack(data[15:], "sssur")
+ cipher, kdfname, kdf_options, num_pubkeys, remainder = cstruct
+ # For now, just support 1 key.
+ if num_pubkeys > 1:
+ raise SSHException(
+ "unsupported: private keyfile has multiple keys"
+ )
+ pubkey, privkey_blob = self._uint32_cstruct_unpack(remainder, "ss")
+
+ if kdfname == b("bcrypt"):
+ if cipher == b("aes256-cbc"):
+ mode = modes.CBC
+ elif cipher == b("aes256-ctr"):
+ mode = modes.CTR
+ else:
+ raise SSHException(
+ "unknown cipher `{}` used in private key file".format(
+ cipher.decode("utf-8")
+ )
+ )
+ # Encrypted private key.
+ # If no password was passed in, raise an exception pointing
+ # out that we need one
+ if password is None:
+ raise PasswordRequiredException(
+ "private key file is encrypted"
+ )
+
+ # Unpack salt and rounds from kdfoptions
+ salt, rounds = self._uint32_cstruct_unpack(kdf_options, "su")
+
+ # run bcrypt kdf to derive key and iv/nonce (32 + 16 bytes)
+ key_iv = bcrypt.kdf(
+ b(password),
+ b(salt),
+ 48,
+ rounds,
+ # We can't control how many rounds are on disk, so no sense
+ # warning about it.
+ ignore_few_rounds=True,
+ )
+ key = key_iv[:32]
+ iv = key_iv[32:]
+
+ # decrypt private key blob
+ decryptor = Cipher(
+ algorithms.AES(key), mode(iv), default_backend()
+ ).decryptor()
+ decrypted_privkey = decryptor.update(privkey_blob)
+ decrypted_privkey += decryptor.finalize()
+ elif cipher == b("none") and kdfname == b("none"):
+ # Unencrypted private key
+ decrypted_privkey = privkey_blob
+ else:
+ raise SSHException(
+ "unknown cipher or kdf used in private key file"
+ )
+
+ # Unpack private key and verify checkints
+ cstruct = self._uint32_cstruct_unpack(decrypted_privkey, "uusr")
+ checkint1, checkint2, keytype, keydata = cstruct
+
+ if checkint1 != checkint2:
+ raise SSHException(
+ "OpenSSH private key file checkints do not match"
+ )
+
+ return _unpad_openssh(keydata)
def _uint32_cstruct_unpack(self, data, strformat):
"""
@@ -366,7 +696,41 @@ class PKey:
u - denotes a 32-bit unsigned integer
r - the remainder of the input string, returned as a string
"""
- pass
+ arr = []
+ idx = 0
+ try:
+ for f in strformat:
+ if f == "s":
+ # string
+ s_size = struct.unpack(">L", data[idx : idx + 4])[0]
+ idx += 4
+ s = data[idx : idx + s_size]
+ idx += s_size
+ arr.append(s)
+ if f == "i":
+ # long integer
+ s_size = struct.unpack(">L", data[idx : idx + 4])[0]
+ idx += 4
+ s = data[idx : idx + s_size]
+ idx += s_size
+ i = util.inflate_long(s, True)
+ arr.append(i)
+ elif f == "u":
+ # 32-bit unsigned int
+ u = struct.unpack(">L", data[idx : idx + 4])[0]
+ idx += 4
+ arr.append(u)
+ elif f == "r":
+ # remainder as string
+ s = data[idx:]
+ arr.append(s)
+ break
+ except Exception as e:
+ # PKey-consuming code frequently wants to save-and-skip-over issues
+ # with loading keys, and uses SSHException as the (really friggin
+ # awful) signal for this. So for now...we do this.
+ raise SSHException(str(e))
+ return tuple(arr)
def _write_private_key_file(self, filename, key, format, password=None):
"""
@@ -383,7 +747,36 @@ class PKey:
:raises: ``IOError`` -- if there was an error writing the file.
"""
- pass
+ # Ensure that we create new key files directly with a user-only mode,
+ # instead of opening, writing, then chmodding, which leaves us open to
+ # CVE-2022-24302.
+ with os.fdopen(
+ os.open(
+ filename,
+ # NOTE: O_TRUNC is a noop on new files, and O_CREAT is a noop
+ # on existing files, so using all 3 in both cases is fine.
+ flags=os.O_WRONLY | os.O_TRUNC | os.O_CREAT,
+ # Ditto the use of the 'mode' argument; it should be safe to
+ # give even for existing files (though it will not act like a
+ # chmod in that case).
+ mode=o600,
+ ),
+ # Yea, you still gotta inform the FLO that it is in "write" mode.
+ "w",
+ ) as f:
+ self._write_private_key(f, key, format, password=password)
+
+ def _write_private_key(self, f, key, format, password=None):
+ if password is None:
+ encryption = serialization.NoEncryption()
+ else:
+ encryption = serialization.BestAvailableEncryption(b(password))
+
+ f.write(
+ key.private_bytes(
+ serialization.Encoding.PEM, format, encryption
+ ).decode()
+ )
def _check_type_and_load_cert(self, msg, key_type, cert_type):
"""
@@ -396,7 +789,43 @@ class PKey:
The obtained key type is returned for classes which need to know what
it was (e.g. ECDSA.)
"""
- pass
+ # Normalization; most classes have a single key type and give a string,
+ # but eg ECDSA is a 1:N mapping.
+ key_types = key_type
+ cert_types = cert_type
+ if isinstance(key_type, str):
+ key_types = [key_types]
+ if isinstance(cert_types, str):
+ cert_types = [cert_types]
+ # Can't do much with no message, that should've been handled elsewhere
+ if msg is None:
+ raise SSHException("Key object may not be empty")
+ # First field is always key type, in either kind of object. (make sure
+ # we rewind before grabbing it - sometimes caller had to do their own
+ # introspection first!)
+ msg.rewind()
+ type_ = msg.get_text()
+ # Regular public key - nothing special to do besides the implicit
+ # type check.
+ if type_ in key_types:
+ pass
+ # OpenSSH-compatible certificate - store full copy as .public_blob
+ # (so signing works correctly) and then fast-forward past the
+ # nonce.
+ elif type_ in cert_types:
+ # This seems the cleanest way to 'clone' an already-being-read
+ # message; they're *IO objects at heart and their .getvalue()
+ # always returns the full value regardless of pointer position.
+ self.load_certificate(Message(msg.asbytes()))
+ # Read out nonce as it comes before the public numbers - our caller
+ # is likely going to use the (only borrowed by us, not owned)
+ # 'msg' object for loading those numbers right after this.
+ # TODO: usefully interpret it & other non-public-number fields
+ # (requires going back into per-type subclasses.)
+ msg.get_string()
+ else:
+ err = "Invalid key (class: {}, data type: {}"
+ raise SSHException(err.format(self.__class__.__name__, type_))
def load_certificate(self, value):
"""
@@ -417,9 +846,25 @@ class PKey:
that is for the server to decide if it is good enough to authenticate
successfully.
"""
- pass
-
-
+ if isinstance(value, Message):
+ constructor = "from_message"
+ elif os.path.isfile(value):
+ constructor = "from_file"
+ else:
+ constructor = "from_string"
+ blob = getattr(PublicBlob, constructor)(value)
+ if not blob.key_type.startswith(self.get_name()):
+ err = "PublicBlob type {} incompatible with key type {}"
+ raise ValueError(err.format(blob.key_type, self.get_name()))
+ self.public_blob = blob
+
+
+# General construct for an OpenSSH style Public Key blob
+# readable from a one-line file of the format:
+# <key-name> <base64-blob> [<comment>]
+# Of little value in the case of standard public keys
+# {ssh-rsa, ssh-dss, ssh-ecdsa, ssh-ed25519}, but should
+# provide rudimentary support for {*-cert.v01}
class PublicBlob:
"""
OpenSSH plain public key or OpenSSH signed public key (certificate).
@@ -451,14 +896,36 @@ class PublicBlob:
"""
Create a public blob from a ``-cert.pub``-style file on disk.
"""
- pass
+ with open(filename) as f:
+ string = f.read()
+ return cls.from_string(string)
@classmethod
def from_string(cls, string):
"""
Create a public blob from a ``-cert.pub``-style string.
"""
- pass
+ fields = string.split(None, 2)
+ if len(fields) < 2:
+ msg = "Not enough fields for public blob: {}"
+ raise ValueError(msg.format(fields))
+ key_type = fields[0]
+ key_blob = decodebytes(b(fields[1]))
+ try:
+ comment = fields[2].strip()
+ except IndexError:
+ comment = None
+ # Verify that the blob message first (string) field matches the
+ # key_type
+ m = Message(key_blob)
+ blob_type = m.get_text()
+ if blob_type != key_type:
+ deets = "key type={!r}, but blob type={!r}".format(
+ key_type, blob_type
+ )
+ raise ValueError("Invalid PublicBlob contents: {}".format(deets))
+ # All good? All good.
+ return cls(type_=key_type, blob=key_blob, comment=comment)
@classmethod
def from_message(cls, message):
@@ -468,15 +935,17 @@ class PublicBlob:
Specifically, a cert-bearing pubkey auth packet, because by definition
OpenSSH-style certificates 'are' their own network representation."
"""
- pass
+ type_ = message.get_text()
+ return cls(type_=type_, blob=message.asbytes())
def __str__(self):
- ret = '{} public key/certificate'.format(self.key_type)
+ ret = "{} public key/certificate".format(self.key_type)
if self.comment:
- ret += '- {}'.format(self.comment)
+ ret += "- {}".format(self.comment)
return ret
def __eq__(self, other):
+ # Just piggyback on Message/BytesIO, since both of these should be one.
return self and other and self.key_blob == other.key_blob
def __ne__(self, other):
diff --git a/paramiko/primes.py b/paramiko/primes.py
index c0ded8f9..663c58ed 100644
--- a/paramiko/primes.py
+++ b/paramiko/primes.py
@@ -1,7 +1,27 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Utility functions for dealing with primes.
"""
+
import os
+
from paramiko import util
from paramiko.common import byte_mask
from paramiko.ssh_exception import SSHException
@@ -9,7 +29,24 @@ from paramiko.ssh_exception import SSHException
def _roll_random(n):
"""returns a random # from 0 to N-1"""
- pass
+ bits = util.bit_length(n - 1)
+ byte_count = (bits + 7) // 8
+ hbyte_mask = pow(2, bits % 8) - 1
+
+ # so here's the plan:
+ # we fetch as many random bits as we'd need to fit N-1, and if the
+ # generated number is >= N, we try again. in the worst case (N-1 is a
+ # power of 2), we have slightly better than 50% odds of getting one that
+ # fits, so i can't guarantee that this loop will ever finish, but the odds
+ # of it looping forever should be infinitesimal.
+ while True:
+ x = os.urandom(byte_count)
+ if hbyte_mask > 0:
+ x = byte_mask(x[0], hbyte_mask) + x[1:]
+ num = util.inflate_long(x, 1)
+ if num < n:
+ break
+ return num
class ModulusPack:
@@ -19,11 +56,93 @@ class ModulusPack:
"""
def __init__(self):
+ # pack is a hash of: bits -> [ (generator, modulus) ... ]
self.pack = {}
self.discarded = []
+ def _parse_modulus(self, line):
+ (
+ timestamp,
+ mod_type,
+ tests,
+ tries,
+ size,
+ generator,
+ modulus,
+ ) = line.split()
+ mod_type = int(mod_type)
+ tests = int(tests)
+ tries = int(tries)
+ size = int(size)
+ generator = int(generator)
+ modulus = int(modulus, 16)
+
+ # weed out primes that aren't at least:
+ # type 2 (meets basic structural requirements)
+ # test 4 (more than just a small-prime sieve)
+ # tries >= 100 if test & 4 (at least 100 tries of miller-rabin)
+ if (
+ mod_type < 2
+ or tests < 4
+ or (tests & 4 and tests < 8 and tries < 100)
+ ):
+ self.discarded.append(
+ (modulus, "does not meet basic requirements")
+ )
+ return
+ if generator == 0:
+ generator = 2
+
+ # there's a bug in the ssh "moduli" file (yeah, i know: shock! dismay!
+ # call cnn!) where it understates the bit lengths of these primes by 1.
+ # this is okay.
+ bl = util.bit_length(modulus)
+ if (bl != size) and (bl != size + 1):
+ self.discarded.append(
+ (modulus, "incorrectly reported bit length {}".format(size))
+ )
+ return
+ if bl not in self.pack:
+ self.pack[bl] = []
+ self.pack[bl].append((generator, modulus))
+
def read_file(self, filename):
"""
:raises IOError: passed from any file operations that fail.
"""
- pass
+ self.pack = {}
+ with open(filename, "r") as f:
+ for line in f:
+ line = line.strip()
+ if (len(line) == 0) or (line[0] == "#"):
+ continue
+ try:
+ self._parse_modulus(line)
+ except:
+ continue
+
+ def get_modulus(self, min, prefer, max):
+ bitsizes = sorted(self.pack.keys())
+ if len(bitsizes) == 0:
+ raise SSHException("no moduli available")
+ good = -1
+ # find nearest bitsize >= preferred
+ for b in bitsizes:
+ if (b >= prefer) and (b <= max) and (b < good or good == -1):
+ good = b
+ # if that failed, find greatest bitsize >= min
+ if good == -1:
+ for b in bitsizes:
+ if (b >= min) and (b <= max) and (b > good):
+ good = b
+ if good == -1:
+ # their entire (min, max) range has no intersection with our range.
+ # if their range is below ours, pick the smallest. otherwise pick
+ # the largest. it'll be out of their range requirement either way,
+ # but we'll be sending them the closest one we have.
+ good = bitsizes[0]
+ if min > good:
+ good = bitsizes[-1]
+ # now pick a random modulus of this bitsize
+ n = _roll_random(len(self.pack[good]))
+ return self.pack[good][n]
diff --git a/paramiko/proxy.py b/paramiko/proxy.py
index 2d1ebe34..f7609c98 100644
--- a/paramiko/proxy.py
+++ b/paramiko/proxy.py
@@ -1,14 +1,37 @@
+# Copyright (C) 2012 Yipit, Inc <coders@yipit.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
+
import os
import shlex
import signal
from select import select
import socket
import time
+
+# Try-and-ignore import so platforms w/o subprocess (eg Google App Engine) can
+# still import paramiko.
subprocess, subprocess_import_error = None, None
try:
import subprocess
except ImportError as e:
subprocess_import_error = e
+
from paramiko.ssh_exception import ProxyCommandFailure
from paramiko.util import ClosingContextManager
@@ -36,8 +59,13 @@ class ProxyCommand(ClosingContextManager):
if subprocess is None:
raise subprocess_import_error
self.cmd = shlex.split(command_line)
- self.process = subprocess.Popen(self.cmd, stdin=subprocess.PIPE,
- stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0)
+ self.process = subprocess.Popen(
+ self.cmd,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ bufsize=0,
+ )
self.timeout = None
def send(self, content):
@@ -47,7 +75,15 @@ class ProxyCommand(ClosingContextManager):
:param str content: string to be sent to the forked command
"""
- pass
+ try:
+ self.process.stdin.write(content)
+ except IOError as e:
+ # There was a problem with the child process. It probably
+ # died and we can't proceed. The best option here is to
+ # raise an exception informing the user that the given
+ # ProxyCommand is not working.
+ raise ProxyCommandFailure(" ".join(self.cmd), e.strerror)
+ return len(content)
def recv(self, size):
"""
@@ -57,4 +93,42 @@ class ProxyCommand(ClosingContextManager):
:return: the string of bytes read, which may be shorter than requested
"""
- pass
+ try:
+ buffer = b""
+ start = time.time()
+ while len(buffer) < size:
+ select_timeout = None
+ if self.timeout is not None:
+ elapsed = time.time() - start
+ if elapsed >= self.timeout:
+ raise socket.timeout()
+ select_timeout = self.timeout - elapsed
+
+ r, w, x = select([self.process.stdout], [], [], select_timeout)
+ if r and r[0] == self.process.stdout:
+ buffer += os.read(
+ self.process.stdout.fileno(), size - len(buffer)
+ )
+ return buffer
+ except socket.timeout:
+ if buffer:
+ # Don't raise socket.timeout, return partial result instead
+ return buffer
+ raise # socket.timeout is a subclass of IOError
+ except IOError as e:
+ raise ProxyCommandFailure(" ".join(self.cmd), e.strerror)
+
+ def close(self):
+ os.kill(self.process.pid, signal.SIGTERM)
+
+ @property
+ def closed(self):
+ return self.process.returncode is not None
+
+ @property
+ def _closed(self):
+ # Concession to Python 3 socket-like API
+ return self.closed
+
+ def settimeout(self, timeout):
+ self.timeout = timeout
diff --git a/paramiko/rsakey.py b/paramiko/rsakey.py
index 5e60a19c..b7ad3ce2 100644
--- a/paramiko/rsakey.py
+++ b/paramiko/rsakey.py
@@ -1,10 +1,30 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
RSA keys.
"""
+
from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding
+
from paramiko.message import Message
from paramiko.pkey import PKey
from paramiko.ssh_exception import SSHException
@@ -15,14 +35,26 @@ class RSAKey(PKey):
Representation of an RSA key which can be used to sign and verify SSH2
data.
"""
- name = 'ssh-rsa'
- HASHES = {'ssh-rsa': hashes.SHA1, 'ssh-rsa-cert-v01@openssh.com':
- hashes.SHA1, 'rsa-sha2-256': hashes.SHA256,
- 'rsa-sha2-256-cert-v01@openssh.com': hashes.SHA256, 'rsa-sha2-512':
- hashes.SHA512, 'rsa-sha2-512-cert-v01@openssh.com': hashes.SHA512}
-
- def __init__(self, msg=None, data=None, filename=None, password=None,
- key=None, file_obj=None):
+
+ name = "ssh-rsa"
+ HASHES = {
+ "ssh-rsa": hashes.SHA1,
+ "ssh-rsa-cert-v01@openssh.com": hashes.SHA1,
+ "rsa-sha2-256": hashes.SHA256,
+ "rsa-sha2-256-cert-v01@openssh.com": hashes.SHA256,
+ "rsa-sha2-512": hashes.SHA512,
+ "rsa-sha2-512-cert-v01@openssh.com": hashes.SHA512,
+ }
+
+ def __init__(
+ self,
+ msg=None,
+ data=None,
+ filename=None,
+ password=None,
+ key=None,
+ file_obj=None,
+ ):
self.key = None
self.public_blob = None
if file_obj is not None:
@@ -31,18 +63,117 @@ class RSAKey(PKey):
if filename is not None:
self._from_private_key_file(filename, password)
return
- if msg is None and data is not None:
+ if (msg is None) and (data is not None):
msg = Message(data)
if key is not None:
self.key = key
else:
- self._check_type_and_load_cert(msg=msg, key_type=self.name,
- cert_type='ssh-rsa-cert-v01@openssh.com')
- self.key = rsa.RSAPublicNumbers(e=msg.get_mpint(), n=msg.
- get_mpint()).public_key(default_backend())
+ self._check_type_and_load_cert(
+ msg=msg,
+ # NOTE: this does NOT change when using rsa2 signatures; it's
+ # purely about key loading, not exchange or verification
+ key_type=self.name,
+ cert_type="ssh-rsa-cert-v01@openssh.com",
+ )
+ self.key = rsa.RSAPublicNumbers(
+ e=msg.get_mpint(), n=msg.get_mpint()
+ ).public_key(default_backend())
+
+ @classmethod
+ def identifiers(cls):
+ return list(cls.HASHES.keys())
+
+ @property
+ def size(self):
+ return self.key.key_size
+
+ @property
+ def public_numbers(self):
+ if isinstance(self.key, rsa.RSAPrivateKey):
+ return self.key.private_numbers().public_numbers
+ else:
+ return self.key.public_numbers()
+
+ def asbytes(self):
+ m = Message()
+ m.add_string(self.name)
+ m.add_mpint(self.public_numbers.e)
+ m.add_mpint(self.public_numbers.n)
+ return m.asbytes()
def __str__(self):
- return self.asbytes().decode('utf8', errors='ignore')
+ # NOTE: see #853 to explain some legacy behavior.
+ # TODO 4.0: replace with a nice clean fingerprint display or something
+ return self.asbytes().decode("utf8", errors="ignore")
+
+ @property
+ def _fields(self):
+ return (self.get_name(), self.public_numbers.e, self.public_numbers.n)
+
+ def get_name(self):
+ return self.name
+
+ def get_bits(self):
+ return self.size
+
+ def can_sign(self):
+ return isinstance(self.key, rsa.RSAPrivateKey)
+
+ def sign_ssh_data(self, data, algorithm=None):
+ if algorithm is None:
+ algorithm = self.name
+ sig = self.key.sign(
+ data,
+ padding=padding.PKCS1v15(),
+ # HASHES being just a map from long identifier to either SHA1 or
+ # SHA256 - cert'ness is not truly relevant.
+ algorithm=self.HASHES[algorithm](),
+ )
+ m = Message()
+ # And here again, cert'ness is irrelevant, so it is stripped out.
+ m.add_string(algorithm.replace("-cert-v01@openssh.com", ""))
+ m.add_string(sig)
+ return m
+
+ def verify_ssh_sig(self, data, msg):
+ sig_algorithm = msg.get_text()
+ if sig_algorithm not in self.HASHES:
+ return False
+ key = self.key
+ if isinstance(key, rsa.RSAPrivateKey):
+ key = key.public_key()
+
+ # NOTE: pad received signature with leading zeros, key.verify()
+ # expects a signature of key size (e.g. PuTTY doesn't pad)
+ sign = msg.get_binary()
+ diff = key.key_size - len(sign) * 8
+ if diff > 0:
+ sign = b"\x00" * ((diff + 7) // 8) + sign
+
+ try:
+ key.verify(
+ sign, data, padding.PKCS1v15(), self.HASHES[sig_algorithm]()
+ )
+ except InvalidSignature:
+ return False
+ else:
+ return True
+
+ def write_private_key_file(self, filename, password=None):
+ self._write_private_key_file(
+ filename,
+ self.key,
+ serialization.PrivateFormat.TraditionalOpenSSL,
+ password=password,
+ )
+
+ def write_private_key(self, file_obj, password=None):
+ self._write_private_key(
+ file_obj,
+ self.key,
+ serialization.PrivateFormat.TraditionalOpenSSL,
+ password=password,
+ )
@staticmethod
def generate(bits, progress_func=None):
@@ -54,4 +185,43 @@ class RSAKey(PKey):
:param progress_func: Unused
:return: new `.RSAKey` private key
"""
- pass
+ key = rsa.generate_private_key(
+ public_exponent=65537, key_size=bits, backend=default_backend()
+ )
+ return RSAKey(key=key)
+
+ # ...internals...
+
+ def _from_private_key_file(self, filename, password):
+ data = self._read_private_key_file("RSA", filename, password)
+ self._decode_key(data)
+
+ def _from_private_key(self, file_obj, password):
+ data = self._read_private_key("RSA", file_obj, password)
+ self._decode_key(data)
+
+ def _decode_key(self, data):
+ pkformat, data = data
+ if pkformat == self._PRIVATE_KEY_FORMAT_ORIGINAL:
+ try:
+ key = serialization.load_der_private_key(
+ data, password=None, backend=default_backend()
+ )
+ except (ValueError, TypeError, UnsupportedAlgorithm) as e:
+ raise SSHException(str(e))
+ elif pkformat == self._PRIVATE_KEY_FORMAT_OPENSSH:
+ n, e, d, iqmp, p, q = self._uint32_cstruct_unpack(data, "iiiiii")
+ public_numbers = rsa.RSAPublicNumbers(e=e, n=n)
+ key = rsa.RSAPrivateNumbers(
+ p=p,
+ q=q,
+ d=d,
+ dmp1=d % (p - 1),
+ dmq1=d % (q - 1),
+ iqmp=iqmp,
+ public_numbers=public_numbers,
+ ).private_key(default_backend())
+ else:
+ self._got_bad_key_format_id(pkformat)
+ assert isinstance(key, rsa.RSAPrivateKey)
+ self.key = key
diff --git a/paramiko/server.py b/paramiko/server.py
index dc283021..6923bdf5 100644
--- a/paramiko/server.py
+++ b/paramiko/server.py
@@ -1,9 +1,34 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
`.ServerInterface` is an interface to override for server support.
"""
+
import threading
from paramiko import util
-from paramiko.common import DEBUG, ERROR, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, AUTH_FAILED, AUTH_SUCCESSFUL
+from paramiko.common import (
+ DEBUG,
+ ERROR,
+ OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
+ AUTH_FAILED,
+ AUTH_SUCCESSFUL,
+)
class ServerInterface:
@@ -59,7 +84,7 @@ class ServerInterface:
:param int chanid: ID of the channel
:return: an `int` success or failure code (listed above)
"""
- pass
+ return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
def get_allowed_auths(self, username):
"""
@@ -76,7 +101,7 @@ class ServerInterface:
:param str username: the username requesting authentication.
:return: a comma-separated `str` of authentication types
"""
- pass
+ return "password"
def check_auth_none(self, username):
"""
@@ -95,7 +120,7 @@ class ServerInterface:
it succeeds.
:rtype: int
"""
- pass
+ return AUTH_FAILED
def check_auth_password(self, username, password):
"""
@@ -120,7 +145,7 @@ class ServerInterface:
successful, but authentication must continue.
:rtype: int
"""
- pass
+ return AUTH_FAILED
def check_auth_publickey(self, username, key):
"""
@@ -152,7 +177,7 @@ class ServerInterface:
authentication
:rtype: int
"""
- pass
+ return AUTH_FAILED
def check_auth_interactive(self, username, submethods):
"""
@@ -177,7 +202,7 @@ class ServerInterface:
object containing queries for the user
:rtype: int or `.InteractiveQuery`
"""
- pass
+ return AUTH_FAILED
def check_auth_interactive_response(self, responses):
"""
@@ -208,10 +233,11 @@ class ServerInterface:
object containing queries for the user
:rtype: int or `.InteractiveQuery`
"""
- pass
+ return AUTH_FAILED
- def check_auth_gssapi_with_mic(self, username, gss_authenticated=
- AUTH_FAILED, cc_file=None):
+ def check_auth_gssapi_with_mic(
+ self, username, gss_authenticated=AUTH_FAILED, cc_file=None
+ ):
"""
Authenticate the given user to the server if he is a valid krb5
principal.
@@ -235,10 +261,13 @@ class ServerInterface:
log in as a user.
:see: http://www.unix.com/man-page/all/3/krb5_kuserok/
"""
- pass
+ if gss_authenticated == AUTH_SUCCESSFUL:
+ return AUTH_SUCCESSFUL
+ return AUTH_FAILED
- def check_auth_gssapi_keyex(self, username, gss_authenticated=
- AUTH_FAILED, cc_file=None):
+ def check_auth_gssapi_keyex(
+ self, username, gss_authenticated=AUTH_FAILED, cc_file=None
+ ):
"""
Authenticate the given user to the server if he is a valid krb5
principal and GSS-API Key Exchange was performed.
@@ -264,7 +293,9 @@ class ServerInterface:
to log in as a user.
:see: http://www.unix.com/man-page/all/3/krb5_kuserok/
"""
- pass
+ if gss_authenticated == AUTH_SUCCESSFUL:
+ return AUTH_SUCCESSFUL
+ return AUTH_FAILED
def enable_auth_gssapi(self):
"""
@@ -275,7 +306,8 @@ class ServerInterface:
:returns bool: Whether GSSAPI authentication is enabled.
:see: `.ssh_gss`
"""
- pass
+ UseGSSAPI = False
+ return UseGSSAPI
def check_port_forward_request(self, address, port):
"""
@@ -296,7 +328,7 @@ class ServerInterface:
the port number (`int`) that was opened for listening, or ``False``
to reject
"""
- pass
+ return False
def cancel_port_forward_request(self, address, port):
"""
@@ -337,10 +369,13 @@ class ServerInterface:
``True`` or a `tuple` of data if the request was granted; ``False``
otherwise.
"""
- pass
+ return False
- def check_channel_pty_request(self, channel, term, width, height,
- pixelwidth, pixelheight, modes):
+ # ...Channel requests...
+
+ def check_channel_pty_request(
+ self, channel, term, width, height, pixelwidth, pixelheight, modes
+ ):
"""
Determine if a pseudo-terminal of the given dimensions (usually
requested for shell access) can be provided on the given channel.
@@ -359,7 +394,7 @@ class ServerInterface:
``True`` if the pseudo-terminal has been allocated; ``False``
otherwise.
"""
- pass
+ return False
def check_channel_shell_request(self, channel):
"""
@@ -375,7 +410,7 @@ class ServerInterface:
``True`` if this channel is now hooked up to a shell; ``False`` if
a shell can't or won't be provided.
"""
- pass
+ return False
def check_channel_exec_request(self, channel, command):
"""
@@ -394,7 +429,7 @@ class ServerInterface:
.. versionadded:: 1.1
"""
- pass
+ return False
def check_channel_subsystem_request(self, channel, name):
"""
@@ -418,10 +453,17 @@ class ServerInterface:
``True`` if this channel is now hooked up to the requested
subsystem; ``False`` if that subsystem can't or won't be provided.
"""
- pass
+ transport = channel.get_transport()
+ handler_class, args, kwargs = transport._get_subsystem_handler(name)
+ if handler_class is None:
+ return False
+ handler = handler_class(channel, name, self, *args, **kwargs)
+ handler.start()
+ return True
- def check_channel_window_change_request(self, channel, width, height,
- pixelwidth, pixelheight):
+ def check_channel_window_change_request(
+ self, channel, width, height, pixelwidth, pixelheight
+ ):
"""
Determine if the pseudo-terminal on the given channel can be resized.
This only makes sense if a pty was previously allocated on it.
@@ -437,10 +479,16 @@ class ServerInterface:
height of screen in pixels, if known (may be ``0`` if unknown).
:return: ``True`` if the terminal was resized; ``False`` if not.
"""
- pass
+ return False
- def check_channel_x11_request(self, channel, single_connection,
- auth_protocol, auth_cookie, screen_number):
+ def check_channel_x11_request(
+ self,
+ channel,
+ single_connection,
+ auth_protocol,
+ auth_cookie,
+ screen_number,
+ ):
"""
Determine if the client will be provided with an X11 session. If this
method returns ``True``, X11 applications should be routed through new
@@ -457,7 +505,7 @@ class ServerInterface:
:param int screen_number: the number of the X11 screen to connect to
:return: ``True`` if the X11 session was opened; ``False`` if not
"""
- pass
+ return False
def check_channel_forward_agent_request(self, channel):
"""
@@ -473,7 +521,7 @@ class ServerInterface:
If ``True`` is returned, the server should create an
:class:`AgentServerProxy` to access the agent.
"""
- pass
+ return False
def check_channel_direct_tcpip_request(self, chanid, origin, destination):
"""
@@ -513,7 +561,7 @@ class ServerInterface:
(server side)
:return: an `int` success or failure code (listed above)
"""
- pass
+ return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
def check_channel_env_request(self, channel, name, value):
"""
@@ -531,7 +579,7 @@ class ServerInterface:
:param str value: Channel value
:returns: A boolean
"""
- pass
+ return False
def get_banner(self):
"""
@@ -545,7 +593,7 @@ class ServerInterface:
.. versionadded:: 2.3
"""
- pass
+ return (None, None)
class InteractiveQuery:
@@ -553,7 +601,7 @@ class InteractiveQuery:
A query (set of prompts) for a user during interactive authentication.
"""
- def __init__(self, name='', instructions='', *prompts):
+ def __init__(self, name="", instructions="", *prompts):
"""
Create a new interactive query to send to the client. The name and
instructions are optional, but are generally displayed to the end
@@ -584,7 +632,7 @@ class InteractiveQuery:
``True`` (default) if the user's response should be echoed;
``False`` if not (for a password or similar)
"""
- pass
+ self.prompts.append((prompt, echo))
class SubsystemHandler(threading.Thread):
@@ -627,7 +675,26 @@ class SubsystemHandler(threading.Thread):
Return the `.ServerInterface` object associated with this channel and
subsystem.
"""
- pass
+ return self.__server
+
+ def _run(self):
+ try:
+ self.__transport._log(
+ DEBUG, "Starting handler for subsystem {}".format(self.__name)
+ )
+ self.start_subsystem(self.__name, self.__transport, self.__channel)
+ except Exception as e:
+ self.__transport._log(
+ ERROR,
+ 'Exception in subsystem handler for "{}": {}'.format(
+ self.__name, e
+ ),
+ )
+ self.__transport._log(ERROR, util.tb_strings())
+ try:
+ self.finish_subsystem()
+ except:
+ pass
def start_subsystem(self, name, transport, channel):
"""
@@ -662,4 +729,4 @@ class SubsystemHandler(threading.Thread):
.. versionadded:: 1.1
"""
- pass
+ self.__channel.close()
diff --git a/paramiko/sftp.py b/paramiko/sftp.py
index 65109f59..b3528d4e 100644
--- a/paramiko/sftp.py
+++ b/paramiko/sftp.py
@@ -1,40 +1,127 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
import select
import socket
import struct
+
from paramiko import util
from paramiko.common import DEBUG, byte_chr, byte_ord
from paramiko.message import Message
-(CMD_INIT, CMD_VERSION, CMD_OPEN, CMD_CLOSE, CMD_READ, CMD_WRITE, CMD_LSTAT,
- CMD_FSTAT, CMD_SETSTAT, CMD_FSETSTAT, CMD_OPENDIR, CMD_READDIR,
- CMD_REMOVE, CMD_MKDIR, CMD_RMDIR, CMD_REALPATH, CMD_STAT, CMD_RENAME,
- CMD_READLINK, CMD_SYMLINK) = range(1, 21)
-CMD_STATUS, CMD_HANDLE, CMD_DATA, CMD_NAME, CMD_ATTRS = range(101, 106)
-CMD_EXTENDED, CMD_EXTENDED_REPLY = range(200, 202)
+
+
+(
+ CMD_INIT,
+ CMD_VERSION,
+ CMD_OPEN,
+ CMD_CLOSE,
+ CMD_READ,
+ CMD_WRITE,
+ CMD_LSTAT,
+ CMD_FSTAT,
+ CMD_SETSTAT,
+ CMD_FSETSTAT,
+ CMD_OPENDIR,
+ CMD_READDIR,
+ CMD_REMOVE,
+ CMD_MKDIR,
+ CMD_RMDIR,
+ CMD_REALPATH,
+ CMD_STAT,
+ CMD_RENAME,
+ CMD_READLINK,
+ CMD_SYMLINK,
+) = range(1, 21)
+(CMD_STATUS, CMD_HANDLE, CMD_DATA, CMD_NAME, CMD_ATTRS) = range(101, 106)
+(CMD_EXTENDED, CMD_EXTENDED_REPLY) = range(200, 202)
+
SFTP_OK = 0
-(SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE,
- SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST,
- SFTP_OP_UNSUPPORTED) = range(1, 9)
-SFTP_DESC = ['Success', 'End of file', 'No such file', 'Permission denied',
- 'Failure', 'Bad message', 'No connection', 'Connection lost',
- 'Operation unsupported']
-SFTP_FLAG_READ = 1
-SFTP_FLAG_WRITE = 2
-SFTP_FLAG_APPEND = 4
-SFTP_FLAG_CREATE = 8
-SFTP_FLAG_TRUNC = 16
-SFTP_FLAG_EXCL = 32
+(
+ SFTP_EOF,
+ SFTP_NO_SUCH_FILE,
+ SFTP_PERMISSION_DENIED,
+ SFTP_FAILURE,
+ SFTP_BAD_MESSAGE,
+ SFTP_NO_CONNECTION,
+ SFTP_CONNECTION_LOST,
+ SFTP_OP_UNSUPPORTED,
+) = range(1, 9)
+
+SFTP_DESC = [
+ "Success",
+ "End of file",
+ "No such file",
+ "Permission denied",
+ "Failure",
+ "Bad message",
+ "No connection",
+ "Connection lost",
+ "Operation unsupported",
+]
+
+SFTP_FLAG_READ = 0x1
+SFTP_FLAG_WRITE = 0x2
+SFTP_FLAG_APPEND = 0x4
+SFTP_FLAG_CREATE = 0x8
+SFTP_FLAG_TRUNC = 0x10
+SFTP_FLAG_EXCL = 0x20
+
_VERSION = 3
-CMD_NAMES = {CMD_INIT: 'init', CMD_VERSION: 'version', CMD_OPEN: 'open',
- CMD_CLOSE: 'close', CMD_READ: 'read', CMD_WRITE: 'write', CMD_LSTAT:
- 'lstat', CMD_FSTAT: 'fstat', CMD_SETSTAT: 'setstat', CMD_FSETSTAT:
- 'fsetstat', CMD_OPENDIR: 'opendir', CMD_READDIR: 'readdir', CMD_REMOVE:
- 'remove', CMD_MKDIR: 'mkdir', CMD_RMDIR: 'rmdir', CMD_REALPATH:
- 'realpath', CMD_STAT: 'stat', CMD_RENAME: 'rename', CMD_READLINK:
- 'readlink', CMD_SYMLINK: 'symlink', CMD_STATUS: 'status', CMD_HANDLE:
- 'handle', CMD_DATA: 'data', CMD_NAME: 'name', CMD_ATTRS: 'attrs',
- CMD_EXTENDED: 'extended', CMD_EXTENDED_REPLY: 'extended_reply'}
+# for debugging
+CMD_NAMES = {
+ CMD_INIT: "init",
+ CMD_VERSION: "version",
+ CMD_OPEN: "open",
+ CMD_CLOSE: "close",
+ CMD_READ: "read",
+ CMD_WRITE: "write",
+ CMD_LSTAT: "lstat",
+ CMD_FSTAT: "fstat",
+ CMD_SETSTAT: "setstat",
+ CMD_FSETSTAT: "fsetstat",
+ CMD_OPENDIR: "opendir",
+ CMD_READDIR: "readdir",
+ CMD_REMOVE: "remove",
+ CMD_MKDIR: "mkdir",
+ CMD_RMDIR: "rmdir",
+ CMD_REALPATH: "realpath",
+ CMD_STAT: "stat",
+ CMD_RENAME: "rename",
+ CMD_READLINK: "readlink",
+ CMD_SYMLINK: "symlink",
+ CMD_STATUS: "status",
+ CMD_HANDLE: "handle",
+ CMD_DATA: "data",
+ CMD_NAME: "name",
+ CMD_ATTRS: "attrs",
+ CMD_EXTENDED: "extended",
+ CMD_EXTENDED_REPLY: "extended_reply",
+}
+
+
+# TODO: rewrite SFTP file/server modules' overly-flexible "make a request with
+# xyz components" so we don't need this very silly method of signaling whether
+# a given Python integer should be 32- or 64-bit.
+# NOTE: this only became an issue when dropping Python 2 support; prior to
+# doing so, we had to support actual-longs, which served as that signal. This
+# is simply recreating that structure in a more tightly scoped fashion.
class int64(int):
pass
@@ -44,8 +131,94 @@ class SFTPError(Exception):
class BaseSFTP:
-
def __init__(self):
- self.logger = util.get_logger('paramiko.sftp')
+ self.logger = util.get_logger("paramiko.sftp")
self.sock = None
self.ultra_debug = False
+
+ # ...internals...
+
+ def _send_version(self):
+ m = Message()
+ m.add_int(_VERSION)
+ self._send_packet(CMD_INIT, m)
+ t, data = self._read_packet()
+ if t != CMD_VERSION:
+ raise SFTPError("Incompatible sftp protocol")
+ version = struct.unpack(">I", data[:4])[0]
+ # if version != _VERSION:
+ # raise SFTPError('Incompatible sftp protocol')
+ return version
+
+ def _send_server_version(self):
+ # winscp will freak out if the server sends version info before the
+ # client finishes sending INIT.
+ t, data = self._read_packet()
+ if t != CMD_INIT:
+ raise SFTPError("Incompatible sftp protocol")
+ version = struct.unpack(">I", data[:4])[0]
+ # advertise that we support "check-file"
+ extension_pairs = ["check-file", "md5,sha1"]
+ msg = Message()
+ msg.add_int(_VERSION)
+ msg.add(*extension_pairs)
+ self._send_packet(CMD_VERSION, msg)
+ return version
+
+ def _log(self, level, msg, *args):
+ self.logger.log(level, msg, *args)
+
+ def _write_all(self, out):
+ while len(out) > 0:
+ n = self.sock.send(out)
+ if n <= 0:
+ raise EOFError()
+ if n == len(out):
+ return
+ out = out[n:]
+ return
+
+ def _read_all(self, n):
+ out = bytes()
+ while n > 0:
+ if isinstance(self.sock, socket.socket):
+ # sometimes sftp is used directly over a socket instead of
+ # through a paramiko channel. in this case, check periodically
+ # if the socket is closed. (for some reason, recv() won't ever
+ # return or raise an exception, but calling select on a closed
+ # socket will.)
+ while True:
+ read, write, err = select.select([self.sock], [], [], 0.1)
+ if len(read) > 0:
+ x = self.sock.recv(n)
+ break
+ else:
+ x = self.sock.recv(n)
+
+ if len(x) == 0:
+ raise EOFError()
+ out += x
+ n -= len(x)
+ return out
+
+ def _send_packet(self, t, packet):
+ packet = packet.asbytes()
+ out = struct.pack(">I", len(packet) + 1) + byte_chr(t) + packet
+ if self.ultra_debug:
+ self._log(DEBUG, util.format_binary(out, "OUT: "))
+ self._write_all(out)
+
+ def _read_packet(self):
+ x = self._read_all(4)
+ # most sftp servers won't accept packets larger than about 32k, so
+ # anything with the high byte set (> 16MB) is just garbage.
+ if byte_ord(x[0]):
+ raise SFTPError("Garbage packet received")
+ size = struct.unpack(">I", x)[0]
+ data = self._read_all(size)
+ if self.ultra_debug:
+ self._log(DEBUG, util.format_binary(data, "IN: "))
+ if size > 0:
+ t = byte_ord(data[0])
+ return t, data[1:]
+ return 0, bytes()
diff --git a/paramiko/sftp_attr.py b/paramiko/sftp_attr.py
index 0745a134..18ffbf86 100644
--- a/paramiko/sftp_attr.py
+++ b/paramiko/sftp_attr.py
@@ -1,3 +1,21 @@
+# Copyright (C) 2003-2006 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
import stat
import time
from paramiko.common import x80000000, o700, o70, xffffffff
@@ -21,6 +39,7 @@ class SFTPAttributes:
are stored in a dict named ``attr``. Occasionally, the filename is also
stored, in ``filename``.
"""
+
FLAG_SIZE = 1
FLAG_UIDGID = 2
FLAG_PERMISSIONS = 4
@@ -50,48 +69,150 @@ class SFTPAttributes:
:param str filename: the filename associated with this file.
:return: new `.SFTPAttributes` object with the same attribute fields.
"""
- pass
+ attr = cls()
+ attr.st_size = obj.st_size
+ attr.st_uid = obj.st_uid
+ attr.st_gid = obj.st_gid
+ attr.st_mode = obj.st_mode
+ attr.st_atime = obj.st_atime
+ attr.st_mtime = obj.st_mtime
+ if filename is not None:
+ attr.filename = filename
+ return attr
def __repr__(self):
- return '<SFTPAttributes: {}>'.format(self._debug_str())
+ return "<SFTPAttributes: {}>".format(self._debug_str())
+
+ # ...internals...
+ @classmethod
+ def _from_msg(cls, msg, filename=None, longname=None):
+ attr = cls()
+ attr._unpack(msg)
+ if filename is not None:
+ attr.filename = filename
+ if longname is not None:
+ attr.longname = longname
+ return attr
+
+ def _unpack(self, msg):
+ self._flags = msg.get_int()
+ if self._flags & self.FLAG_SIZE:
+ self.st_size = msg.get_int64()
+ if self._flags & self.FLAG_UIDGID:
+ self.st_uid = msg.get_int()
+ self.st_gid = msg.get_int()
+ if self._flags & self.FLAG_PERMISSIONS:
+ self.st_mode = msg.get_int()
+ if self._flags & self.FLAG_AMTIME:
+ self.st_atime = msg.get_int()
+ self.st_mtime = msg.get_int()
+ if self._flags & self.FLAG_EXTENDED:
+ count = msg.get_int()
+ for i in range(count):
+ self.attr[msg.get_string()] = msg.get_string()
+
+ def _pack(self, msg):
+ self._flags = 0
+ if self.st_size is not None:
+ self._flags |= self.FLAG_SIZE
+ if (self.st_uid is not None) and (self.st_gid is not None):
+ self._flags |= self.FLAG_UIDGID
+ if self.st_mode is not None:
+ self._flags |= self.FLAG_PERMISSIONS
+ if (self.st_atime is not None) and (self.st_mtime is not None):
+ self._flags |= self.FLAG_AMTIME
+ if len(self.attr) > 0:
+ self._flags |= self.FLAG_EXTENDED
+ msg.add_int(self._flags)
+ if self._flags & self.FLAG_SIZE:
+ msg.add_int64(self.st_size)
+ if self._flags & self.FLAG_UIDGID:
+ msg.add_int(self.st_uid)
+ msg.add_int(self.st_gid)
+ if self._flags & self.FLAG_PERMISSIONS:
+ msg.add_int(self.st_mode)
+ if self._flags & self.FLAG_AMTIME:
+ # throw away any fractional seconds
+ msg.add_int(int(self.st_atime))
+ msg.add_int(int(self.st_mtime))
+ if self._flags & self.FLAG_EXTENDED:
+ msg.add_int(len(self.attr))
+ for key, val in self.attr.items():
+ msg.add_string(key)
+ msg.add_string(val)
+ return
+
+ def _debug_str(self):
+ out = "[ "
+ if self.st_size is not None:
+ out += "size={} ".format(self.st_size)
+ if (self.st_uid is not None) and (self.st_gid is not None):
+ out += "uid={} gid={} ".format(self.st_uid, self.st_gid)
+ if self.st_mode is not None:
+ out += "mode=" + oct(self.st_mode) + " "
+ if (self.st_atime is not None) and (self.st_mtime is not None):
+ out += "atime={} mtime={} ".format(self.st_atime, self.st_mtime)
+ for k, v in self.attr.items():
+ out += '"{}"={!r} '.format(str(k), v)
+ out += "]"
+ return out
+
+ @staticmethod
+ def _rwx(n, suid, sticky=False):
+ if suid:
+ suid = 2
+ out = "-r"[n >> 2] + "-w"[(n >> 1) & 1]
+ if sticky:
+ out += "-xTt"[suid + (n & 1)]
+ else:
+ out += "-xSs"[suid + (n & 1)]
+ return out
def __str__(self):
"""create a unix-style long description of the file (like ls -l)"""
if self.st_mode is not None:
kind = stat.S_IFMT(self.st_mode)
if kind == stat.S_IFIFO:
- ks = 'p'
+ ks = "p"
elif kind == stat.S_IFCHR:
- ks = 'c'
+ ks = "c"
elif kind == stat.S_IFDIR:
- ks = 'd'
+ ks = "d"
elif kind == stat.S_IFBLK:
- ks = 'b'
+ ks = "b"
elif kind == stat.S_IFREG:
- ks = '-'
+ ks = "-"
elif kind == stat.S_IFLNK:
- ks = 'l'
+ ks = "l"
elif kind == stat.S_IFSOCK:
- ks = 's'
+ ks = "s"
else:
- ks = '?'
- ks += self._rwx((self.st_mode & o700) >> 6, self.st_mode & stat
- .S_ISUID)
- ks += self._rwx((self.st_mode & o70) >> 3, self.st_mode & stat.
- S_ISGID)
- ks += self._rwx(self.st_mode & 7, self.st_mode & stat.S_ISVTX, True
- )
+ ks = "?"
+ ks += self._rwx(
+ (self.st_mode & o700) >> 6, self.st_mode & stat.S_ISUID
+ )
+ ks += self._rwx(
+ (self.st_mode & o70) >> 3, self.st_mode & stat.S_ISGID
+ )
+ ks += self._rwx(
+ self.st_mode & 7, self.st_mode & stat.S_ISVTX, True
+ )
else:
- ks = '?---------'
- if self.st_mtime is None or self.st_mtime == xffffffff:
- datestr = '(unknown date)'
+ ks = "?---------"
+ # compute display date
+ if (self.st_mtime is None) or (self.st_mtime == xffffffff):
+ # shouldn't really happen
+ datestr = "(unknown date)"
else:
time_tuple = time.localtime(self.st_mtime)
- if abs(time.time() - self.st_mtime) > 15552000:
- datestr = time.strftime('%d %b %Y', time_tuple)
+ if abs(time.time() - self.st_mtime) > 15_552_000:
+ # (15,552,000s = 6 months)
+ datestr = time.strftime("%d %b %Y", time_tuple)
else:
- datestr = time.strftime('%d %b %H:%M', time_tuple)
- filename = getattr(self, 'filename', '?')
+ datestr = time.strftime("%d %b %H:%M", time_tuple)
+ filename = getattr(self, "filename", "?")
+
+ # not all servers support uid/gid
uid = self.st_uid
gid = self.st_gid
size = self.st_size
@@ -101,5 +222,18 @@ class SFTPAttributes:
gid = 0
if size is None:
size = 0
- return '%s 1 %-8d %-8d %8d %-12s %s' % (ks, uid, gid, size,
- datestr, filename)
+
+ # TODO: not sure this actually worked as expected beforehand, leaving
+ # it untouched for the time being, re: .format() upgrade, until someone
+ # has time to doublecheck
+ return "%s 1 %-8d %-8d %8d %-12s %s" % (
+ ks,
+ uid,
+ gid,
+ size,
+ datestr,
+ filename,
+ )
+
+ def asbytes(self):
+ return str(self).encode()
diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py
index 24ff487a..066cd83f 100644
--- a/paramiko/sftp_client.py
+++ b/paramiko/sftp_client.py
@@ -1,3 +1,22 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of Paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
+
from binascii import hexlify
import errno
import os
@@ -9,7 +28,41 @@ from paramiko import util
from paramiko.channel import Channel
from paramiko.message import Message
from paramiko.common import INFO, DEBUG, o777
-from paramiko.sftp import BaseSFTP, CMD_OPENDIR, CMD_HANDLE, SFTPError, CMD_READDIR, CMD_NAME, CMD_CLOSE, SFTP_FLAG_READ, SFTP_FLAG_WRITE, SFTP_FLAG_CREATE, SFTP_FLAG_TRUNC, SFTP_FLAG_APPEND, SFTP_FLAG_EXCL, CMD_OPEN, CMD_REMOVE, CMD_RENAME, CMD_MKDIR, CMD_RMDIR, CMD_STAT, CMD_ATTRS, CMD_LSTAT, CMD_SYMLINK, CMD_SETSTAT, CMD_READLINK, CMD_REALPATH, CMD_STATUS, CMD_EXTENDED, SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, int64
+from paramiko.sftp import (
+ BaseSFTP,
+ CMD_OPENDIR,
+ CMD_HANDLE,
+ SFTPError,
+ CMD_READDIR,
+ CMD_NAME,
+ CMD_CLOSE,
+ SFTP_FLAG_READ,
+ SFTP_FLAG_WRITE,
+ SFTP_FLAG_CREATE,
+ SFTP_FLAG_TRUNC,
+ SFTP_FLAG_APPEND,
+ SFTP_FLAG_EXCL,
+ CMD_OPEN,
+ CMD_REMOVE,
+ CMD_RENAME,
+ CMD_MKDIR,
+ CMD_RMDIR,
+ CMD_STAT,
+ CMD_ATTRS,
+ CMD_LSTAT,
+ CMD_SYMLINK,
+ CMD_SETSTAT,
+ CMD_READLINK,
+ CMD_REALPATH,
+ CMD_STATUS,
+ CMD_EXTENDED,
+ SFTP_OK,
+ SFTP_EOF,
+ SFTP_NO_SUCH_FILE,
+ SFTP_PERMISSION_DENIED,
+ int64,
+)
+
from paramiko.sftp_attr import SFTPAttributes
from paramiko.ssh_exception import SSHException
from paramiko.sftp_file import SFTPFile
@@ -22,10 +75,16 @@ def _to_unicode(s):
protocol). if neither works, just return a byte string because the server
probably doesn't know the filename's encoding.
"""
- pass
+ try:
+ return s.encode("ascii")
+ except (UnicodeError, AttributeError):
+ try:
+ return s.decode("utf-8")
+ except UnicodeError:
+ return s
-b_slash = b'/'
+b_slash = b"/"
class SFTPClient(BaseSFTP, ClosingContextManager):
@@ -55,20 +114,28 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
self.sock = sock
self.ultra_debug = False
self.request_number = 1
+ # lock for request_number
self._lock = threading.Lock()
self._cwd = None
+ # request # -> SFTPFile
self._expecting = weakref.WeakValueDictionary()
if type(sock) is Channel:
+ # override default logger
transport = self.sock.get_transport()
- self.logger = util.get_logger(transport.get_log_channel() + '.sftp'
- )
+ self.logger = util.get_logger(
+ transport.get_log_channel() + ".sftp"
+ )
self.ultra_debug = transport.get_hexdump()
try:
server_version = self._send_version()
except EOFError:
- raise SSHException('EOF during negotiation')
- self._log(INFO, 'Opened sftp connection (server version {})'.format
- (server_version))
+ raise SSHException("EOF during negotiation")
+ self._log(
+ INFO,
+ "Opened sftp connection (server version {})".format(
+ server_version
+ ),
+ )
@classmethod
def from_transport(cls, t, window_size=None, max_packet_size=None):
@@ -94,7 +161,29 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionchanged:: 1.15
Added the ``window_size`` and ``max_packet_size`` arguments.
"""
- pass
+ chan = t.open_session(
+ window_size=window_size, max_packet_size=max_packet_size
+ )
+ if chan is None:
+ return None
+ chan.invoke_subsystem("sftp")
+ return cls(chan)
+
+ def _log(self, level, msg, *args):
+ if isinstance(msg, list):
+ for m in msg:
+ self._log(level, m, *args)
+ else:
+ # NOTE: these bits MUST continue using %-style format junk because
+ # logging.Logger.log() explicitly requires it. Grump.
+ # escape '%' in msg (they could come from file or directory names)
+ # before logging
+ msg = msg.replace("%", "%%")
+ super()._log(
+ level,
+ "[chan %s] " + msg,
+ *([self.sock.get_name()] + list(args))
+ )
def close(self):
"""
@@ -102,7 +191,8 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionadded:: 1.4
"""
- pass
+ self._log(INFO, "sftp session closed.")
+ self.sock.close()
def get_channel(self):
"""
@@ -111,9 +201,9 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionadded:: 1.7.1
"""
- pass
+ return self.sock
- def listdir(self, path='.'):
+ def listdir(self, path="."):
"""
Return a list containing the names of the entries in the given
``path``.
@@ -125,9 +215,9 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param str path: path to list (defaults to ``'.'``)
"""
- pass
+ return [f.filename for f in self.listdir_attr(path)]
- def listdir_attr(self, path='.'):
+ def listdir_attr(self, path="."):
"""
Return a list containing `.SFTPAttributes` objects corresponding to
files in the given ``path``. The list is in arbitrary order. It does
@@ -144,9 +234,32 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionadded:: 1.2
"""
- pass
-
- def listdir_iter(self, path='.', read_aheads=50):
+ path = self._adjust_cwd(path)
+ self._log(DEBUG, "listdir({!r})".format(path))
+ t, msg = self._request(CMD_OPENDIR, path)
+ if t != CMD_HANDLE:
+ raise SFTPError("Expected handle")
+ handle = msg.get_binary()
+ filelist = []
+ while True:
+ try:
+ t, msg = self._request(CMD_READDIR, handle)
+ except EOFError:
+ # done with handle
+ break
+ if t != CMD_NAME:
+ raise SFTPError("Expected name response")
+ count = msg.get_int()
+ for i in range(count):
+ filename = msg.get_text()
+ longname = msg.get_text()
+ attr = SFTPAttributes._from_msg(msg, filename, longname)
+ if (filename != ".") and (filename != ".."):
+ filelist.append(attr)
+ self._request(CMD_CLOSE, handle)
+ return filelist
+
+ def listdir_iter(self, path=".", read_aheads=50):
"""
Generator version of `.listdir_attr`.
@@ -160,9 +273,57 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionadded:: 1.15
"""
- pass
-
- def open(self, filename, mode='r', bufsize=-1):
+ path = self._adjust_cwd(path)
+ self._log(DEBUG, "listdir({!r})".format(path))
+ t, msg = self._request(CMD_OPENDIR, path)
+
+ if t != CMD_HANDLE:
+ raise SFTPError("Expected handle")
+
+ handle = msg.get_string()
+
+ nums = list()
+ while True:
+ try:
+ # Send out a bunch of readdir requests so that we can read the
+ # responses later on Section 6.7 of the SSH file transfer RFC
+ # explains this
+ # http://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt
+ for i in range(read_aheads):
+ num = self._async_request(type(None), CMD_READDIR, handle)
+ nums.append(num)
+
+ # For each of our sent requests
+ # Read and parse the corresponding packets
+ # If we're at the end of our queued requests, then fire off
+ # some more requests
+ # Exit the loop when we've reached the end of the directory
+ # handle
+ for num in nums:
+ t, pkt_data = self._read_packet()
+ msg = Message(pkt_data)
+ new_num = msg.get_int()
+ if num == new_num:
+ if t == CMD_STATUS:
+ self._convert_status(msg)
+ count = msg.get_int()
+ for i in range(count):
+ filename = msg.get_text()
+ longname = msg.get_text()
+ attr = SFTPAttributes._from_msg(
+ msg, filename, longname
+ )
+ if (filename != ".") and (filename != ".."):
+ yield attr
+
+ # If we've hit the end of our queued requests, reset nums.
+ nums = list()
+
+ except EOFError:
+ self._request(CMD_CLOSE, handle)
+ return
+
+ def open(self, filename, mode="r", bufsize=-1):
"""
Open a file on the remote server. The arguments are the same as for
Python's built-in `python:file` (aka `python:open`). A file-like
@@ -194,7 +355,33 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:raises: ``IOError`` -- if the file could not be opened.
"""
- pass
+ filename = self._adjust_cwd(filename)
+ self._log(DEBUG, "open({!r}, {!r})".format(filename, mode))
+ imode = 0
+ if ("r" in mode) or ("+" in mode):
+ imode |= SFTP_FLAG_READ
+ if ("w" in mode) or ("+" in mode) or ("a" in mode):
+ imode |= SFTP_FLAG_WRITE
+ if "w" in mode:
+ imode |= SFTP_FLAG_CREATE | SFTP_FLAG_TRUNC
+ if "a" in mode:
+ imode |= SFTP_FLAG_CREATE | SFTP_FLAG_APPEND
+ if "x" in mode:
+ imode |= SFTP_FLAG_CREATE | SFTP_FLAG_EXCL
+ attrblock = SFTPAttributes()
+ t, msg = self._request(CMD_OPEN, filename, imode, attrblock)
+ if t != CMD_HANDLE:
+ raise SFTPError("Expected handle")
+ handle = msg.get_binary()
+ self._log(
+ DEBUG,
+ "open({!r}, {!r}) -> {}".format(
+ filename, mode, u(hexlify(handle))
+ ),
+ )
+ return SFTPFile(self, handle, mode, bufsize)
+
+ # Python continues to vacillate about "open" vs "file"...
file = open
def remove(self, path):
@@ -206,7 +393,10 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:raises: ``IOError`` -- if the path refers to a folder (directory)
"""
- pass
+ path = self._adjust_cwd(path)
+ self._log(DEBUG, "remove({!r})".format(path))
+ self._request(CMD_REMOVE, path)
+
unlink = remove
def rename(self, oldpath, newpath):
@@ -227,7 +417,10 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
``IOError`` -- if ``newpath`` is a folder, or something else goes
wrong
"""
- pass
+ oldpath = self._adjust_cwd(oldpath)
+ newpath = self._adjust_cwd(newpath)
+ self._log(DEBUG, "rename({!r}, {!r})".format(oldpath, newpath))
+ self._request(CMD_RENAME, oldpath, newpath)
def posix_rename(self, oldpath, newpath):
"""
@@ -244,7 +437,12 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:versionadded: 2.2
"""
- pass
+ oldpath = self._adjust_cwd(oldpath)
+ newpath = self._adjust_cwd(newpath)
+ self._log(DEBUG, "posix_rename({!r}, {!r})".format(oldpath, newpath))
+ self._request(
+ CMD_EXTENDED, "posix-rename@openssh.com", oldpath, newpath
+ )
def mkdir(self, path, mode=o777):
"""
@@ -255,7 +453,11 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param str path: name of the folder to create
:param int mode: permissions (posix-style) for the newly-created folder
"""
- pass
+ path = self._adjust_cwd(path)
+ self._log(DEBUG, "mkdir({!r}, {!r})".format(path, mode))
+ attr = SFTPAttributes()
+ attr.st_mode = mode
+ self._request(CMD_MKDIR, path, attr)
def rmdir(self, path):
"""
@@ -263,7 +465,9 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param str path: name of the folder to remove
"""
- pass
+ path = self._adjust_cwd(path)
+ self._log(DEBUG, "rmdir({!r})".format(path))
+ self._request(CMD_RMDIR, path)
def stat(self, path):
"""
@@ -284,7 +488,12 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
an `.SFTPAttributes` object containing attributes about the given
file
"""
- pass
+ path = self._adjust_cwd(path)
+ self._log(DEBUG, "stat({!r})".format(path))
+ t, msg = self._request(CMD_STAT, path)
+ if t != CMD_ATTRS:
+ raise SFTPError("Expected attributes")
+ return SFTPAttributes._from_msg(msg)
def lstat(self, path):
"""
@@ -297,7 +506,12 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
an `.SFTPAttributes` object containing attributes about the given
file
"""
- pass
+ path = self._adjust_cwd(path)
+ self._log(DEBUG, "lstat({!r})".format(path))
+ t, msg = self._request(CMD_LSTAT, path)
+ if t != CMD_ATTRS:
+ raise SFTPError("Expected attributes")
+ return SFTPAttributes._from_msg(msg)
def symlink(self, source, dest):
"""
@@ -306,7 +520,10 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param str source: path of the original file
:param str dest: path of the newly created symlink
"""
- pass
+ dest = self._adjust_cwd(dest)
+ self._log(DEBUG, "symlink({!r}, {!r})".format(source, dest))
+ source = b(source)
+ self._request(CMD_SYMLINK, source, dest)
def chmod(self, path, mode):
"""
@@ -317,7 +534,11 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param str path: path of the file to change the permissions of
:param int mode: new permissions
"""
- pass
+ path = self._adjust_cwd(path)
+ self._log(DEBUG, "chmod({!r}, {!r})".format(path, mode))
+ attr = SFTPAttributes()
+ attr.st_mode = mode
+ self._request(CMD_SETSTAT, path, attr)
def chown(self, path, uid, gid):
"""
@@ -330,7 +551,11 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param int uid: new owner's uid
:param int gid: new group id
"""
- pass
+ path = self._adjust_cwd(path)
+ self._log(DEBUG, "chown({!r}, {!r}, {!r})".format(path, uid, gid))
+ attr = SFTPAttributes()
+ attr.st_uid, attr.st_gid = uid, gid
+ self._request(CMD_SETSTAT, path, attr)
def utime(self, path, times):
"""
@@ -346,7 +571,13 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
``None`` or a tuple of (access time, modified time) in standard
internet epoch time (seconds since 01 January 1970 GMT)
"""
- pass
+ path = self._adjust_cwd(path)
+ if times is None:
+ times = (time.time(), time.time())
+ self._log(DEBUG, "utime({!r}, {!r})".format(path, times))
+ attr = SFTPAttributes()
+ attr.st_atime, attr.st_mtime = times
+ self._request(CMD_SETSTAT, path, attr)
def truncate(self, path, size):
"""
@@ -357,7 +588,11 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param str path: path of the file to modify
:param int size: the new size of the file
"""
- pass
+ path = self._adjust_cwd(path)
+ self._log(DEBUG, "truncate({!r}, {!r})".format(path, size))
+ attr = SFTPAttributes()
+ attr.st_size = size
+ self._request(CMD_SETSTAT, path, attr)
def readlink(self, path):
"""
@@ -368,7 +603,17 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:param str path: path of the symbolic link file
:return: target path, as a `str`
"""
- pass
+ path = self._adjust_cwd(path)
+ self._log(DEBUG, "readlink({!r})".format(path))
+ t, msg = self._request(CMD_READLINK, path)
+ if t != CMD_NAME:
+ raise SFTPError("Expected name response")
+ count = msg.get_int()
+ if count == 0:
+ return None
+ if count != 1:
+ raise SFTPError("Readlink returned {} results".format(count))
+ return _to_unicode(msg.get_string())
def normalize(self, path):
"""
@@ -382,7 +627,15 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
:raises: ``IOError`` -- if the path can't be resolved on the server
"""
- pass
+ path = self._adjust_cwd(path)
+ self._log(DEBUG, "normalize({!r})".format(path))
+ t, msg = self._request(CMD_REALPATH, path)
+ if t != CMD_NAME:
+ raise SFTPError("Expected name response")
+ count = msg.get_int()
+ if count != 1:
+ raise SFTPError("Realpath returned {} results".format(count))
+ return msg.get_text()
def chdir(self, path=None):
"""
@@ -400,7 +653,13 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionadded:: 1.4
"""
- pass
+ if path is None:
+ self._cwd = None
+ return
+ if not stat.S_ISDIR(self.stat(path).st_mode):
+ code = errno.ENOTDIR
+ raise SFTPError(code, "{}: {}".format(os.strerror(code), path))
+ self._cwd = b(self.normalize(path))
def getcwd(self):
"""
@@ -410,7 +669,20 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionadded:: 1.4
"""
- pass
+ # TODO: make class initialize with self._cwd set to self.normalize('.')
+ return self._cwd and u(self._cwd)
+
+ def _transfer_with_callback(self, reader, writer, file_size, callback):
+ size = 0
+ while True:
+ data = reader.read(32768)
+ writer.write(data)
+ size += len(data)
+ if len(data) == 0:
+ break
+ if callback is not None:
+ callback(size, file_size)
+ return size
def putfo(self, fl, remotepath, file_size=0, callback=None, confirm=True):
"""
@@ -439,7 +711,20 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionadded:: 1.10
"""
- pass
+ with self.file(remotepath, "wb") as fr:
+ fr.set_pipelined(True)
+ size = self._transfer_with_callback(
+ reader=fl, writer=fr, file_size=file_size, callback=callback
+ )
+ if confirm:
+ s = self.stat(remotepath)
+ if s.st_size != size:
+ raise IOError(
+ "size mismatch in put! {} != {}".format(s.st_size, size)
+ )
+ else:
+ s = SFTPAttributes()
+ return s
def put(self, localpath, remotepath, callback=None, confirm=True):
"""
@@ -469,10 +754,18 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionchanged:: 1.7.7
``confirm`` param added.
"""
- pass
+ file_size = os.stat(localpath).st_size
+ with open(localpath, "rb") as fl:
+ return self.putfo(fl, remotepath, file_size, callback, confirm)
- def getfo(self, remotepath, fl, callback=None, prefetch=True,
- max_concurrent_prefetch_requests=None):
+ def getfo(
+ self,
+ remotepath,
+ fl,
+ callback=None,
+ prefetch=True,
+ max_concurrent_prefetch_requests=None,
+ ):
"""
Copy a remote file (``remotepath``) from the SFTP server and write to
an open file or file-like object, ``fl``. Any exception raised by
@@ -499,10 +792,22 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionchanged:: 3.3
Added ``max_concurrent_prefetch_requests``.
"""
- pass
-
- def get(self, remotepath, localpath, callback=None, prefetch=True,
- max_concurrent_prefetch_requests=None):
+ file_size = self.stat(remotepath).st_size
+ with self.open(remotepath, "rb") as fr:
+ if prefetch:
+ fr.prefetch(file_size, max_concurrent_prefetch_requests)
+ return self._transfer_with_callback(
+ reader=fr, writer=fl, file_size=file_size, callback=callback
+ )
+
+ def get(
+ self,
+ remotepath,
+ localpath,
+ callback=None,
+ prefetch=True,
+ max_concurrent_prefetch_requests=None,
+ ):
"""
Copy a remote file (``remotepath``) from the SFTP server to the local
host as ``localpath``. Any exception raised by operations will be
@@ -531,24 +836,130 @@ class SFTPClient(BaseSFTP, ClosingContextManager):
.. versionchanged:: 3.3
Added ``max_concurrent_prefetch_requests``.
"""
- pass
+ with open(localpath, "wb") as fl:
+ size = self.getfo(
+ remotepath,
+ fl,
+ callback,
+ prefetch,
+ max_concurrent_prefetch_requests,
+ )
+ s = os.stat(localpath)
+ if s.st_size != size:
+ raise IOError(
+ "size mismatch in get! {} != {}".format(s.st_size, size)
+ )
+
+ # ...internals...
+
+ def _request(self, t, *args):
+ num = self._async_request(type(None), t, *args)
+ return self._read_response(num)
+
+ def _async_request(self, fileobj, t, *args):
+ # this method may be called from other threads (prefetch)
+ self._lock.acquire()
+ try:
+ msg = Message()
+ msg.add_int(self.request_number)
+ for item in args:
+ if isinstance(item, int64):
+ msg.add_int64(item)
+ elif isinstance(item, int):
+ msg.add_int(item)
+ elif isinstance(item, SFTPAttributes):
+ item._pack(msg)
+ else:
+ # For all other types, rely on as_string() to either coerce
+ # to bytes before writing or raise a suitable exception.
+ msg.add_string(item)
+ num = self.request_number
+ self._expecting[num] = fileobj
+ self.request_number += 1
+ finally:
+ self._lock.release()
+ self._send_packet(t, msg)
+ return num
+
+ def _read_response(self, waitfor=None):
+ while True:
+ try:
+ t, data = self._read_packet()
+ except EOFError as e:
+ raise SSHException("Server connection dropped: {}".format(e))
+ msg = Message(data)
+ num = msg.get_int()
+ self._lock.acquire()
+ try:
+ if num not in self._expecting:
+ # might be response for a file that was closed before
+ # responses came back
+ self._log(DEBUG, "Unexpected response #{}".format(num))
+ if waitfor is None:
+ # just doing a single check
+ break
+ continue
+ fileobj = self._expecting[num]
+ del self._expecting[num]
+ finally:
+ self._lock.release()
+ if num == waitfor:
+ # synchronous
+ if t == CMD_STATUS:
+ self._convert_status(msg)
+ return t, msg
+
+ # can not rewrite this to deal with E721, either as a None check
+ # nor as not an instance of None or NoneType
+ if fileobj is not type(None): # noqa
+ fileobj._async_response(t, msg, num)
+ if waitfor is None:
+ # just doing a single check
+ break
+ return None, None
+
+ def _finish_responses(self, fileobj):
+ while fileobj in self._expecting.values():
+ self._read_response()
+ fileobj._check_exception()
def _convert_status(self, msg):
"""
Raises EOFError or IOError on error status; otherwise does nothing.
"""
- pass
+ code = msg.get_int()
+ text = msg.get_text()
+ if code == SFTP_OK:
+ return
+ elif code == SFTP_EOF:
+ raise EOFError(text)
+ elif code == SFTP_NO_SUCH_FILE:
+ # clever idea from john a. meinel: map the error codes to errno
+ raise IOError(errno.ENOENT, text)
+ elif code == SFTP_PERMISSION_DENIED:
+ raise IOError(errno.EACCES, text)
+ else:
+ raise IOError(text)
def _adjust_cwd(self, path):
"""
Return an adjusted path if we're emulating a "current working
directory" for the server.
"""
- pass
+ path = b(path)
+ if self._cwd is None:
+ return path
+ if len(path) and path[0:1] == b_slash:
+ # absolute path
+ return path
+ if self._cwd == b_slash:
+ return self._cwd + path
+ return self._cwd + b_slash + path
class SFTP(SFTPClient):
"""
An alias for `.SFTPClient` for backwards compatibility.
"""
+
pass
diff --git a/paramiko/sftp_file.py b/paramiko/sftp_file.py
index e4ca900d..c74695e0 100644
--- a/paramiko/sftp_file.py
+++ b/paramiko/sftp_file.py
@@ -1,15 +1,48 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
SFTP file object
"""
+
+
from binascii import hexlify
from collections import deque
import socket
import threading
import time
from paramiko.common import DEBUG, io_sleep
+
from paramiko.file import BufferedFile
from paramiko.util import u
-from paramiko.sftp import CMD_CLOSE, CMD_READ, CMD_DATA, SFTPError, CMD_WRITE, CMD_STATUS, CMD_FSTAT, CMD_ATTRS, CMD_FSETSTAT, CMD_EXTENDED, int64
+from paramiko.sftp import (
+ CMD_CLOSE,
+ CMD_READ,
+ CMD_DATA,
+ SFTPError,
+ CMD_WRITE,
+ CMD_STATUS,
+ CMD_FSTAT,
+ CMD_ATTRS,
+ CMD_FSETSTAT,
+ CMD_EXTENDED,
+ int64,
+)
from paramiko.sftp_attr import SFTPAttributes
@@ -20,9 +53,12 @@ class SFTPFile(BufferedFile):
Instances of this class may be used as context managers in the same way
that built-in Python file objects are.
"""
+
+ # Some sftp servers will choke if you send read/write requests larger than
+ # this size.
MAX_REQUEST_SIZE = 32768
- def __init__(self, sftp, handle, mode='r', bufsize=-1):
+ def __init__(self, sftp, handle, mode="r", bufsize=-1):
BufferedFile.__init__(self)
self.sftp = sftp
self.handle = handle
@@ -43,7 +79,55 @@ class SFTPFile(BufferedFile):
"""
Close the file.
"""
- pass
+ self._close(async_=False)
+
+ def _close(self, async_=False):
+ # We allow double-close without signaling an error, because real
+ # Python file objects do. However, we must protect against actually
+ # sending multiple CMD_CLOSE packets, because after we close our
+ # handle, the same handle may be re-allocated by the server, and we
+ # may end up mysteriously closing some random other file. (This is
+ # especially important because we unconditionally call close() from
+ # __del__.)
+ if self._closed:
+ return
+ self.sftp._log(DEBUG, "close({})".format(u(hexlify(self.handle))))
+ if self.pipelined:
+ self.sftp._finish_responses(self)
+ BufferedFile.close(self)
+ try:
+ if async_:
+ # GC'd file handle could be called from an arbitrary thread
+ # -- don't wait for a response
+ self.sftp._async_request(type(None), CMD_CLOSE, self.handle)
+ else:
+ self.sftp._request(CMD_CLOSE, self.handle)
+ except EOFError:
+ # may have outlived the Transport connection
+ pass
+ except (IOError, socket.error):
+ # may have outlived the Transport connection
+ pass
+
+ def _data_in_prefetch_requests(self, offset, size):
+ k = [
+ x for x in list(self._prefetch_extents.values()) if x[0] <= offset
+ ]
+ if len(k) == 0:
+ return False
+ k.sort(key=lambda x: x[0])
+ buf_offset, buf_size = k[-1]
+ if buf_offset + buf_size <= offset:
+ # prefetch request ends before this one begins
+ return False
+ if buf_offset + buf_size >= offset + size:
+ # inclusive
+ return True
+ # well, we have part of the request. see if another chunk has
+ # the rest.
+ return self._data_in_prefetch_requests(
+ buf_offset + buf_size, offset + size - buf_offset - buf_size
+ )
def _data_in_prefetch_buffers(self, offset):
"""
@@ -52,14 +136,80 @@ class SFTPFile(BufferedFile):
return None. this guarantees nothing about the number of bytes
collected in the prefetch buffer so far.
"""
- pass
+ k = [i for i in self._prefetch_data.keys() if i <= offset]
+ if len(k) == 0:
+ return None
+ index = max(k)
+ buf_offset = offset - index
+ if buf_offset >= len(self._prefetch_data[index]):
+ # it's not here
+ return None
+ return index
def _read_prefetch(self, size):
"""
read data out of the prefetch buffer, if possible. if the data isn't
in the buffer, return None. otherwise, behaves like a normal read.
"""
- pass
+ # while not closed, and haven't fetched past the current position,
+ # and haven't reached EOF...
+ while True:
+ offset = self._data_in_prefetch_buffers(self._realpos)
+ if offset is not None:
+ break
+ if self._prefetch_done or self._closed:
+ break
+ self.sftp._read_response()
+ self._check_exception()
+ if offset is None:
+ self._prefetching = False
+ return None
+ prefetch = self._prefetch_data[offset]
+ del self._prefetch_data[offset]
+
+ buf_offset = self._realpos - offset
+ if buf_offset > 0:
+ self._prefetch_data[offset] = prefetch[:buf_offset]
+ prefetch = prefetch[buf_offset:]
+ if size < len(prefetch):
+ self._prefetch_data[self._realpos + size] = prefetch[size:]
+ prefetch = prefetch[:size]
+ return prefetch
+
+ def _read(self, size):
+ size = min(size, self.MAX_REQUEST_SIZE)
+ if self._prefetching:
+ data = self._read_prefetch(size)
+ if data is not None:
+ return data
+ t, msg = self.sftp._request(
+ CMD_READ, self.handle, int64(self._realpos), int(size)
+ )
+ if t != CMD_DATA:
+ raise SFTPError("Expected data")
+ return msg.get_string()
+
+ def _write(self, data):
+ # may write less than requested if it would exceed max packet size
+ chunk = min(len(data), self.MAX_REQUEST_SIZE)
+ sftp_async_request = self.sftp._async_request(
+ type(None),
+ CMD_WRITE,
+ self.handle,
+ int64(self._realpos),
+ data[:chunk],
+ )
+ self._reqs.append(sftp_async_request)
+ if not self.pipelined or (
+ len(self._reqs) > 100 and self.sftp.sock.recv_ready()
+ ):
+ while len(self._reqs):
+ req = self._reqs.popleft()
+ t, msg = self.sftp._read_response(req)
+ if t != CMD_STATUS:
+ raise SFTPError("Expected status")
+ # convert_status already called
+ return chunk
def settimeout(self, timeout):
"""
@@ -72,7 +222,7 @@ class SFTPFile(BufferedFile):
.. seealso:: `.Channel.settimeout`
"""
- pass
+ self.sftp.sock.settimeout(timeout)
def gettimeout(self):
"""
@@ -81,7 +231,7 @@ class SFTPFile(BufferedFile):
.. seealso:: `.Channel.gettimeout`
"""
- pass
+ return self.sftp.sock.gettimeout()
def setblocking(self, blocking):
"""
@@ -93,7 +243,7 @@ class SFTPFile(BufferedFile):
.. seealso:: `.Channel.setblocking`
"""
- pass
+ self.sftp.sock.setblocking(blocking)
def seekable(self):
"""
@@ -103,7 +253,7 @@ class SFTPFile(BufferedFile):
`True` if the file supports random access. If `False`,
:meth:`seek` will raise an exception
"""
- pass
+ return True
def seek(self, offset, whence=0):
"""
@@ -111,7 +261,15 @@ class SFTPFile(BufferedFile):
See `file.seek` for details.
"""
- pass
+ self.flush()
+ if whence == self.SEEK_SET:
+ self._realpos = self._pos = offset
+ elif whence == self.SEEK_CUR:
+ self._pos += offset
+ self._realpos = self._pos
+ else:
+ self._realpos = self._pos = self._get_size() + offset
+ self._rbuffer = bytes()
def stat(self):
"""
@@ -122,7 +280,10 @@ class SFTPFile(BufferedFile):
:returns:
an `.SFTPAttributes` object containing attributes about this file.
"""
- pass
+ t, msg = self.sftp._request(CMD_FSTAT, self.handle)
+ if t != CMD_ATTRS:
+ raise SFTPError("Expected attributes")
+ return SFTPAttributes._from_msg(msg)
def chmod(self, mode):
"""
@@ -132,7 +293,12 @@ class SFTPFile(BufferedFile):
:param int mode: new permissions
"""
- pass
+ self.sftp._log(
+ DEBUG, "chmod({}, {!r})".format(hexlify(self.handle), mode)
+ )
+ attr = SFTPAttributes()
+ attr.st_mode = mode
+ self.sftp._request(CMD_FSETSTAT, self.handle, attr)
def chown(self, uid, gid):
"""
@@ -144,7 +310,13 @@ class SFTPFile(BufferedFile):
:param int uid: new owner's uid
:param int gid: new group id
"""
- pass
+ self.sftp._log(
+ DEBUG,
+ "chown({}, {!r}, {!r})".format(hexlify(self.handle), uid, gid),
+ )
+ attr = SFTPAttributes()
+ attr.st_uid, attr.st_gid = uid, gid
+ self.sftp._request(CMD_FSETSTAT, self.handle, attr)
def utime(self, times):
"""
@@ -159,7 +331,14 @@ class SFTPFile(BufferedFile):
``None`` or a tuple of (access time, modified time) in standard
internet epoch time (seconds since 01 January 1970 GMT)
"""
- pass
+ if times is None:
+ times = (time.time(), time.time())
+ self.sftp._log(
+ DEBUG, "utime({}, {!r})".format(hexlify(self.handle), times)
+ )
+ attr = SFTPAttributes()
+ attr.st_atime, attr.st_mtime = times
+ self.sftp._request(CMD_FSETSTAT, self.handle, attr)
def truncate(self, size):
"""
@@ -169,7 +348,12 @@ class SFTPFile(BufferedFile):
:param size: the new size of the file
"""
- pass
+ self.sftp._log(
+ DEBUG, "truncate({}, {!r})".format(hexlify(self.handle), size)
+ )
+ attr = SFTPAttributes()
+ attr.st_size = size
+ self.sftp._request(CMD_FSETSTAT, self.handle, attr)
def check(self, hash_algorithm, offset=0, length=0, block_size=0):
"""
@@ -217,7 +401,19 @@ class SFTPFile(BufferedFile):
.. versionadded:: 1.4
"""
- pass
+ t, msg = self.sftp._request(
+ CMD_EXTENDED,
+ "check-file",
+ self.handle,
+ hash_algorithm,
+ int64(offset),
+ int64(length),
+ block_size,
+ )
+ msg.get_text() # ext
+ msg.get_text() # alg
+ data = msg.get_remainder()
+ return data
def set_pipelined(self, pipelined=True):
"""
@@ -237,7 +433,7 @@ class SFTPFile(BufferedFile):
.. versionadded:: 1.5
"""
- pass
+ self.pipelined = pipelined
def prefetch(self, file_size=None, max_concurrent_requests=None):
"""
@@ -272,7 +468,18 @@ class SFTPFile(BufferedFile):
.. versionchanged:: 3.3
Added ``max_concurrent_requests``.
"""
- pass
+ if file_size is None:
+ file_size = self.stat().st_size
+
+ # queue up async reads for the rest of the file
+ chunks = []
+ n = self._realpos
+ while n < file_size:
+ chunk = min(self.MAX_REQUEST_SIZE, file_size - n)
+ chunks.append((n, chunk))
+ n += chunk
+ if len(chunks) > 0:
+ self._start_prefetch(chunks, max_concurrent_requests)
def readv(self, chunks, max_concurrent_prefetch_requests=None):
"""
@@ -294,8 +501,94 @@ class SFTPFile(BufferedFile):
.. versionchanged:: 3.3
Added ``max_concurrent_prefetch_requests``.
"""
- pass
+ self.sftp._log(
+ DEBUG, "readv({}, {!r})".format(hexlify(self.handle), chunks)
+ )
+
+ read_chunks = []
+ for offset, size in chunks:
+ # don't fetch data that's already in the prefetch buffer
+ if self._data_in_prefetch_buffers(
+ offset
+ ) or self._data_in_prefetch_requests(offset, size):
+ continue
+
+ # break up anything larger than the max read size
+ while size > 0:
+ chunk_size = min(size, self.MAX_REQUEST_SIZE)
+ read_chunks.append((offset, chunk_size))
+ offset += chunk_size
+ size -= chunk_size
+
+ self._start_prefetch(read_chunks, max_concurrent_prefetch_requests)
+ # now we can just devolve to a bunch of read()s :)
+ for x in chunks:
+ self.seek(x[0])
+ yield self.read(x[1])
+
+ # ...internals...
+
+ def _get_size(self):
+ try:
+ return self.stat().st_size
+ except:
+ return 0
+
+ def _start_prefetch(self, chunks, max_concurrent_requests=None):
+ self._prefetching = True
+ self._prefetch_done = False
+
+ t = threading.Thread(
+ target=self._prefetch_thread,
+ args=(chunks, max_concurrent_requests),
+ )
+ t.daemon = True
+ t.start()
+
+ def _prefetch_thread(self, chunks, max_concurrent_requests):
+ # do these read requests in a temporary thread because there may be
+ # a lot of them, so it may block.
+ for offset, length in chunks:
+ # Limit the number of concurrent requests in a busy-loop
+ if max_concurrent_requests is not None:
+ while True:
+ with self._prefetch_lock:
+ pf_len = len(self._prefetch_extents)
+ if pf_len < max_concurrent_requests:
+ break
+ time.sleep(io_sleep)
+
+ num = self.sftp._async_request(
+ self, CMD_READ, self.handle, int64(offset), int(length)
+ )
+ with self._prefetch_lock:
+ self._prefetch_extents[num] = (offset, length)
+
+ def _async_response(self, t, msg, num):
+ if t == CMD_STATUS:
+ # save exception and re-raise it on next file operation
+ try:
+ self.sftp._convert_status(msg)
+ except Exception as e:
+ self._saved_exception = e
+ return
+ if t != CMD_DATA:
+ raise SFTPError("Expected data")
+ data = msg.get_string()
+ while True:
+ with self._prefetch_lock:
+ # spin if in race with _prefetch_thread
+ if num in self._prefetch_extents:
+ offset, length = self._prefetch_extents[num]
+ self._prefetch_data[offset] = data
+ del self._prefetch_extents[num]
+ if len(self._prefetch_extents) == 0:
+ self._prefetch_done = True
+ break
def _check_exception(self):
"""if there's a saved exception, raise & clear it"""
- pass
+ if self._saved_exception is not None:
+ x = self._saved_exception
+ self._saved_exception = None
+ raise x
diff --git a/paramiko/sftp_handle.py b/paramiko/sftp_handle.py
index 5b9d4a8b..b2046526 100644
--- a/paramiko/sftp_handle.py
+++ b/paramiko/sftp_handle.py
@@ -1,6 +1,25 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Abstraction of an SFTP file handle (for server mode).
"""
+
import os
from paramiko.sftp import SFTP_OP_UNSUPPORTED, SFTP_OK
from paramiko.util import ClosingContextManager
@@ -29,6 +48,7 @@ class SFTPHandle(ClosingContextManager):
"""
self.__flags = flags
self.__name = None
+ # only for handles to folders:
self.__files = {}
self.__tell = None
@@ -44,7 +64,12 @@ class SFTPHandle(ClosingContextManager):
using the default implementations of `read` and `write`, this
method's default implementation should be fine also.
"""
- pass
+ readfile = getattr(self, "readfile", None)
+ if readfile is not None:
+ readfile.close()
+ writefile = getattr(self, "writefile", None)
+ if writefile is not None:
+ writefile.close()
def read(self, offset, length):
"""
@@ -64,7 +89,21 @@ class SFTPHandle(ClosingContextManager):
:param int length: number of bytes to attempt to read.
:return: the `bytes` read, or an error code `int`.
"""
- pass
+ readfile = getattr(self, "readfile", None)
+ if readfile is None:
+ return SFTP_OP_UNSUPPORTED
+ try:
+ if self.__tell is None:
+ self.__tell = readfile.tell()
+ if offset != self.__tell:
+ readfile.seek(offset)
+ self.__tell = offset
+ data = readfile.read(length)
+ except IOError as e:
+ self.__tell = None
+ return SFTPServer.convert_errno(e.errno)
+ self.__tell += len(data)
+ return data
def write(self, offset, data):
"""
@@ -84,7 +123,25 @@ class SFTPHandle(ClosingContextManager):
:param bytes data: data to write into the file.
:return: an SFTP error code like ``SFTP_OK``.
"""
- pass
+ writefile = getattr(self, "writefile", None)
+ if writefile is None:
+ return SFTP_OP_UNSUPPORTED
+ try:
+ # in append mode, don't care about seeking
+ if (self.__flags & os.O_APPEND) == 0:
+ if self.__tell is None:
+ self.__tell = writefile.tell()
+ if offset != self.__tell:
+ writefile.seek(offset)
+ self.__tell = offset
+ writefile.write(data)
+ writefile.flush()
+ except IOError as e:
+ self.__tell = None
+ return SFTPServer.convert_errno(e.errno)
+ if self.__tell is not None:
+ self.__tell += len(data)
+ return SFTP_OK
def stat(self):
"""
@@ -97,7 +154,7 @@ class SFTPHandle(ClosingContextManager):
(like ``SFTP_PERMISSION_DENIED``).
:rtype: `.SFTPAttributes` or error code
"""
- pass
+ return SFTP_OP_UNSUPPORTED
def chattr(self, attr):
"""
@@ -108,7 +165,9 @@ class SFTPHandle(ClosingContextManager):
:param .SFTPAttributes attr: the attributes to change on this file.
:return: an `int` error code like ``SFTP_OK``.
"""
- pass
+ return SFTP_OP_UNSUPPORTED
+
+ # ...internals...
def _set_files(self, files):
"""
@@ -116,14 +175,22 @@ class SFTPHandle(ClosingContextManager):
the SFTP protocol, listing a directory is a multi-stage process
requiring a temporary handle.)
"""
- pass
+ self.__files = files
def _get_next_files(self):
"""
Used by the SFTP server code to retrieve a cached directory
listing.
"""
- pass
+ fnlist = self.__files[:16]
+ self.__files = self.__files[16:]
+ return fnlist
+
+ def _get_name(self):
+ return self.__name
+
+ def _set_name(self, name):
+ self.__name = name
from paramiko.sftp_server import SFTPServer
diff --git a/paramiko/sftp_server.py b/paramiko/sftp_server.py
index 2ffa92dd..cd3910dc 100644
--- a/paramiko/sftp_server.py
+++ b/paramiko/sftp_server.py
@@ -1,19 +1,88 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Server-mode SFTP support.
"""
+
import os
import errno
import sys
from hashlib import md5, sha1
+
from paramiko import util
-from paramiko.sftp import BaseSFTP, Message, SFTP_FAILURE, SFTP_PERMISSION_DENIED, SFTP_NO_SUCH_FILE, int64
+from paramiko.sftp import (
+ BaseSFTP,
+ Message,
+ SFTP_FAILURE,
+ SFTP_PERMISSION_DENIED,
+ SFTP_NO_SUCH_FILE,
+ int64,
+)
from paramiko.sftp_si import SFTPServerInterface
from paramiko.sftp_attr import SFTPAttributes
from paramiko.common import DEBUG
from paramiko.server import SubsystemHandler
from paramiko.util import b
-from paramiko.sftp import CMD_HANDLE, SFTP_DESC, CMD_STATUS, SFTP_EOF, CMD_NAME, SFTP_BAD_MESSAGE, CMD_EXTENDED_REPLY, SFTP_FLAG_READ, SFTP_FLAG_WRITE, SFTP_FLAG_APPEND, SFTP_FLAG_CREATE, SFTP_FLAG_TRUNC, SFTP_FLAG_EXCL, CMD_NAMES, CMD_OPEN, CMD_CLOSE, SFTP_OK, CMD_READ, CMD_DATA, CMD_WRITE, CMD_REMOVE, CMD_RENAME, CMD_MKDIR, CMD_RMDIR, CMD_OPENDIR, CMD_READDIR, CMD_STAT, CMD_ATTRS, CMD_LSTAT, CMD_FSTAT, CMD_SETSTAT, CMD_FSETSTAT, CMD_READLINK, CMD_SYMLINK, CMD_REALPATH, CMD_EXTENDED, SFTP_OP_UNSUPPORTED
-_hash_class = {'sha1': sha1, 'md5': md5}
+
+
+# known hash algorithms for the "check-file" extension
+from paramiko.sftp import (
+ CMD_HANDLE,
+ SFTP_DESC,
+ CMD_STATUS,
+ SFTP_EOF,
+ CMD_NAME,
+ SFTP_BAD_MESSAGE,
+ CMD_EXTENDED_REPLY,
+ SFTP_FLAG_READ,
+ SFTP_FLAG_WRITE,
+ SFTP_FLAG_APPEND,
+ SFTP_FLAG_CREATE,
+ SFTP_FLAG_TRUNC,
+ SFTP_FLAG_EXCL,
+ CMD_NAMES,
+ CMD_OPEN,
+ CMD_CLOSE,
+ SFTP_OK,
+ CMD_READ,
+ CMD_DATA,
+ CMD_WRITE,
+ CMD_REMOVE,
+ CMD_RENAME,
+ CMD_MKDIR,
+ CMD_RMDIR,
+ CMD_OPENDIR,
+ CMD_READDIR,
+ CMD_STAT,
+ CMD_ATTRS,
+ CMD_LSTAT,
+ CMD_FSTAT,
+ CMD_SETSTAT,
+ CMD_FSETSTAT,
+ CMD_READLINK,
+ CMD_SYMLINK,
+ CMD_REALPATH,
+ CMD_EXTENDED,
+ SFTP_OP_UNSUPPORTED,
+)
+
+_hash_class = {"sha1": sha1, "md5": md5}
class SFTPServer(BaseSFTP, SubsystemHandler):
@@ -23,8 +92,15 @@ class SFTPServer(BaseSFTP, SubsystemHandler):
Use `.Transport.set_subsystem_handler` to activate this class.
"""
- def __init__(self, channel, name, server, sftp_si=SFTPServerInterface,
- *args, **kwargs):
+ def __init__(
+ self,
+ channel,
+ name,
+ server,
+ sftp_si=SFTPServerInterface,
+ *args,
+ **kwargs
+ ):
"""
The constructor for SFTPServer is meant to be called from within the
`.Transport` as a subsystem handler. ``server`` and any additional
@@ -42,13 +118,61 @@ class SFTPServer(BaseSFTP, SubsystemHandler):
BaseSFTP.__init__(self)
SubsystemHandler.__init__(self, channel, name, server)
transport = channel.get_transport()
- self.logger = util.get_logger(transport.get_log_channel() + '.sftp')
+ self.logger = util.get_logger(transport.get_log_channel() + ".sftp")
self.ultra_debug = transport.get_hexdump()
self.next_handle = 1
+ # map of handle-string to SFTPHandle for files & folders:
self.file_table = {}
self.folder_table = {}
self.server = sftp_si(server, *args, **kwargs)
+ def _log(self, level, msg):
+ if issubclass(type(msg), list):
+ for m in msg:
+ super()._log(level, "[chan " + self.sock.get_name() + "] " + m)
+ else:
+ super()._log(level, "[chan " + self.sock.get_name() + "] " + msg)
+
+ def start_subsystem(self, name, transport, channel):
+ self.sock = channel
+ self._log(DEBUG, "Started sftp server on channel {!r}".format(channel))
+ self._send_server_version()
+ self.server.session_started()
+ while True:
+ try:
+ t, data = self._read_packet()
+ except EOFError:
+ self._log(DEBUG, "EOF -- end of session")
+ return
+ except Exception as e:
+ self._log(DEBUG, "Exception on channel: " + str(e))
+ self._log(DEBUG, util.tb_strings())
+ return
+ msg = Message(data)
+ request_number = msg.get_int()
+ try:
+ self._process(t, request_number, msg)
+ except Exception as e:
+ self._log(DEBUG, "Exception in server processing: " + str(e))
+ self._log(DEBUG, util.tb_strings())
+ # send some kind of failure message, at least
+ try:
+ self._send_status(request_number, SFTP_FAILURE)
+ except:
+ pass
+
+ def finish_subsystem(self):
+ self.server.session_ended()
+ super().finish_subsystem()
+ # close any file handles that were left open
+ # (so we can return them to the OS quickly)
+ for f in self.file_table.values():
+ f.close()
+ for f in self.folder_table.values():
+ f.close()
+ self.file_table = {}
+ self.folder_table = {}
+
@staticmethod
def convert_errno(e):
"""
@@ -59,7 +183,14 @@ class SFTPServer(BaseSFTP, SubsystemHandler):
:param int e: an errno code, as from ``OSError.errno``.
:return: an `int` SFTP error code like ``SFTP_NO_SUCH_FILE``.
"""
- pass
+ if e == errno.EACCES:
+ # permission denied
+ return SFTP_PERMISSION_DENIED
+ elif (e == errno.ENOENT) or (e == errno.ENOTDIR):
+ # no such file
+ return SFTP_NO_SUCH_FILE
+ else:
+ return SFTP_FAILURE
@staticmethod
def set_file_attr(filename, attr):
@@ -76,11 +207,331 @@ class SFTPServer(BaseSFTP, SubsystemHandler):
name of the file to alter (should usually be an absolute path).
:param .SFTPAttributes attr: attributes to change.
"""
- pass
+ if sys.platform != "win32":
+ # mode operations are meaningless on win32
+ if attr._flags & attr.FLAG_PERMISSIONS:
+ os.chmod(filename, attr.st_mode)
+ if attr._flags & attr.FLAG_UIDGID:
+ os.chown(filename, attr.st_uid, attr.st_gid)
+ if attr._flags & attr.FLAG_AMTIME:
+ os.utime(filename, (attr.st_atime, attr.st_mtime))
+ if attr._flags & attr.FLAG_SIZE:
+ with open(filename, "w+") as f:
+ f.truncate(attr.st_size)
+
+ # ...internals...
+
+ def _response(self, request_number, t, *args):
+ msg = Message()
+ msg.add_int(request_number)
+ for item in args:
+ # NOTE: this is a very silly tiny class used for SFTPFile mostly
+ if isinstance(item, int64):
+ msg.add_int64(item)
+ elif isinstance(item, int):
+ msg.add_int(item)
+ elif isinstance(item, (str, bytes)):
+ msg.add_string(item)
+ elif type(item) is SFTPAttributes:
+ item._pack(msg)
+ else:
+ raise Exception(
+ "unknown type for {!r} type {!r}".format(item, type(item))
+ )
+ self._send_packet(t, msg)
+
+ def _send_handle_response(self, request_number, handle, folder=False):
+ if not issubclass(type(handle), SFTPHandle):
+ # must be error code
+ self._send_status(request_number, handle)
+ return
+ handle._set_name(b("hx{:d}".format(self.next_handle)))
+ self.next_handle += 1
+ if folder:
+ self.folder_table[handle._get_name()] = handle
+ else:
+ self.file_table[handle._get_name()] = handle
+ self._response(request_number, CMD_HANDLE, handle._get_name())
+
+ def _send_status(self, request_number, code, desc=None):
+ if desc is None:
+ try:
+ desc = SFTP_DESC[code]
+ except IndexError:
+ desc = "Unknown"
+ # some clients expect a "language" tag at the end
+ # (but don't mind it being blank)
+ self._response(request_number, CMD_STATUS, code, desc, "")
+
+ def _open_folder(self, request_number, path):
+ resp = self.server.list_folder(path)
+ if issubclass(type(resp), list):
+ # got an actual list of filenames in the folder
+ folder = SFTPHandle()
+ folder._set_files(resp)
+ self._send_handle_response(request_number, folder, True)
+ return
+ # must be an error code
+ self._send_status(request_number, resp)
+
+ def _read_folder(self, request_number, folder):
+ flist = folder._get_next_files()
+ if len(flist) == 0:
+ self._send_status(request_number, SFTP_EOF)
+ return
+ msg = Message()
+ msg.add_int(request_number)
+ msg.add_int(len(flist))
+ for attr in flist:
+ msg.add_string(attr.filename)
+ msg.add_string(attr)
+ attr._pack(msg)
+ self._send_packet(CMD_NAME, msg)
+
+ def _check_file(self, request_number, msg):
+ # this extension actually comes from v6 protocol, but since it's an
+ # extension, i feel like we can reasonably support it backported.
+ # it's very useful for verifying uploaded files or checking for
+ # rsync-like differences between local and remote files.
+ handle = msg.get_binary()
+ alg_list = msg.get_list()
+ start = msg.get_int64()
+ length = msg.get_int64()
+ block_size = msg.get_int()
+ if handle not in self.file_table:
+ self._send_status(
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
+ return
+ f = self.file_table[handle]
+ for x in alg_list:
+ if x in _hash_class:
+ algname = x
+ alg = _hash_class[x]
+ break
+ else:
+ self._send_status(
+ request_number, SFTP_FAILURE, "No supported hash types found"
+ )
+ return
+ if length == 0:
+ st = f.stat()
+ if not issubclass(type(st), SFTPAttributes):
+ self._send_status(request_number, st, "Unable to stat file")
+ return
+ length = st.st_size - start
+ if block_size == 0:
+ block_size = length
+ if block_size < 256:
+ self._send_status(
+ request_number, SFTP_FAILURE, "Block size too small"
+ )
+ return
+
+ sum_out = bytes()
+ offset = start
+ while offset < start + length:
+ blocklen = min(block_size, start + length - offset)
+ # don't try to read more than about 64KB at a time
+ chunklen = min(blocklen, 65536)
+ count = 0
+ hash_obj = alg()
+ while count < blocklen:
+ data = f.read(offset, chunklen)
+ if not isinstance(data, bytes):
+ self._send_status(
+ request_number, data, "Unable to hash file"
+ )
+ return
+ hash_obj.update(data)
+ count += len(data)
+ offset += count
+ sum_out += hash_obj.digest()
+
+ msg = Message()
+ msg.add_int(request_number)
+ msg.add_string("check-file")
+ msg.add_string(algname)
+ msg.add_bytes(sum_out)
+ self._send_packet(CMD_EXTENDED_REPLY, msg)
def _convert_pflags(self, pflags):
"""convert SFTP-style open() flags to Python's os.open() flags"""
- pass
+ if (pflags & SFTP_FLAG_READ) and (pflags & SFTP_FLAG_WRITE):
+ flags = os.O_RDWR
+ elif pflags & SFTP_FLAG_WRITE:
+ flags = os.O_WRONLY
+ else:
+ flags = os.O_RDONLY
+ if pflags & SFTP_FLAG_APPEND:
+ flags |= os.O_APPEND
+ if pflags & SFTP_FLAG_CREATE:
+ flags |= os.O_CREAT
+ if pflags & SFTP_FLAG_TRUNC:
+ flags |= os.O_TRUNC
+ if pflags & SFTP_FLAG_EXCL:
+ flags |= os.O_EXCL
+ return flags
+
+ def _process(self, t, request_number, msg):
+ self._log(DEBUG, "Request: {}".format(CMD_NAMES[t]))
+ if t == CMD_OPEN:
+ path = msg.get_text()
+ flags = self._convert_pflags(msg.get_int())
+ attr = SFTPAttributes._from_msg(msg)
+ self._send_handle_response(
+ request_number, self.server.open(path, flags, attr)
+ )
+ elif t == CMD_CLOSE:
+ handle = msg.get_binary()
+ if handle in self.folder_table:
+ del self.folder_table[handle]
+ self._send_status(request_number, SFTP_OK)
+ return
+ if handle in self.file_table:
+ self.file_table[handle].close()
+ del self.file_table[handle]
+ self._send_status(request_number, SFTP_OK)
+ return
+ self._send_status(
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
+ elif t == CMD_READ:
+ handle = msg.get_binary()
+ offset = msg.get_int64()
+ length = msg.get_int()
+ if handle not in self.file_table:
+ self._send_status(
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
+ return
+ data = self.file_table[handle].read(offset, length)
+ if isinstance(data, (bytes, str)):
+ if len(data) == 0:
+ self._send_status(request_number, SFTP_EOF)
+ else:
+ self._response(request_number, CMD_DATA, data)
+ else:
+ self._send_status(request_number, data)
+ elif t == CMD_WRITE:
+ handle = msg.get_binary()
+ offset = msg.get_int64()
+ data = msg.get_binary()
+ if handle not in self.file_table:
+ self._send_status(
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
+ return
+ self._send_status(
+ request_number, self.file_table[handle].write(offset, data)
+ )
+ elif t == CMD_REMOVE:
+ path = msg.get_text()
+ self._send_status(request_number, self.server.remove(path))
+ elif t == CMD_RENAME:
+ oldpath = msg.get_text()
+ newpath = msg.get_text()
+ self._send_status(
+ request_number, self.server.rename(oldpath, newpath)
+ )
+ elif t == CMD_MKDIR:
+ path = msg.get_text()
+ attr = SFTPAttributes._from_msg(msg)
+ self._send_status(request_number, self.server.mkdir(path, attr))
+ elif t == CMD_RMDIR:
+ path = msg.get_text()
+ self._send_status(request_number, self.server.rmdir(path))
+ elif t == CMD_OPENDIR:
+ path = msg.get_text()
+ self._open_folder(request_number, path)
+ return
+ elif t == CMD_READDIR:
+ handle = msg.get_binary()
+ if handle not in self.folder_table:
+ self._send_status(
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
+ return
+ folder = self.folder_table[handle]
+ self._read_folder(request_number, folder)
+ elif t == CMD_STAT:
+ path = msg.get_text()
+ resp = self.server.stat(path)
+ if issubclass(type(resp), SFTPAttributes):
+ self._response(request_number, CMD_ATTRS, resp)
+ else:
+ self._send_status(request_number, resp)
+ elif t == CMD_LSTAT:
+ path = msg.get_text()
+ resp = self.server.lstat(path)
+ if issubclass(type(resp), SFTPAttributes):
+ self._response(request_number, CMD_ATTRS, resp)
+ else:
+ self._send_status(request_number, resp)
+ elif t == CMD_FSTAT:
+ handle = msg.get_binary()
+ if handle not in self.file_table:
+ self._send_status(
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
+ return
+ resp = self.file_table[handle].stat()
+ if issubclass(type(resp), SFTPAttributes):
+ self._response(request_number, CMD_ATTRS, resp)
+ else:
+ self._send_status(request_number, resp)
+ elif t == CMD_SETSTAT:
+ path = msg.get_text()
+ attr = SFTPAttributes._from_msg(msg)
+ self._send_status(request_number, self.server.chattr(path, attr))
+ elif t == CMD_FSETSTAT:
+ handle = msg.get_binary()
+ attr = SFTPAttributes._from_msg(msg)
+ if handle not in self.file_table:
+ self._response(
+ request_number, SFTP_BAD_MESSAGE, "Invalid handle"
+ )
+ return
+ self._send_status(
+ request_number, self.file_table[handle].chattr(attr)
+ )
+ elif t == CMD_READLINK:
+ path = msg.get_text()
+ resp = self.server.readlink(path)
+ if isinstance(resp, (bytes, str)):
+ self._response(
+ request_number, CMD_NAME, 1, resp, "", SFTPAttributes()
+ )
+ else:
+ self._send_status(request_number, resp)
+ elif t == CMD_SYMLINK:
+ # the sftp 2 draft is incorrect here!
+ # path always follows target_path
+ target_path = msg.get_text()
+ path = msg.get_text()
+ self._send_status(
+ request_number, self.server.symlink(target_path, path)
+ )
+ elif t == CMD_REALPATH:
+ path = msg.get_text()
+ rpath = self.server.canonicalize(path)
+ self._response(
+ request_number, CMD_NAME, 1, rpath, "", SFTPAttributes()
+ )
+ elif t == CMD_EXTENDED:
+ tag = msg.get_text()
+ if tag == "check-file":
+ self._check_file(request_number, msg)
+ elif tag == "posix-rename@openssh.com":
+ oldpath = msg.get_text()
+ newpath = msg.get_text()
+ self._send_status(
+ request_number, self.server.posix_rename(oldpath, newpath)
+ )
+ else:
+ self._send_status(request_number, SFTP_OP_UNSUPPORTED)
+ else:
+ self._send_status(request_number, SFTP_OP_UNSUPPORTED)
from paramiko.sftp_handle import SFTPHandle
diff --git a/paramiko/sftp_si.py b/paramiko/sftp_si.py
index e0b4e643..72b5db94 100644
--- a/paramiko/sftp_si.py
+++ b/paramiko/sftp_si.py
@@ -1,6 +1,25 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
An interface to override for SFTP server support.
"""
+
import os
import sys
from paramiko.sftp import SFTP_OP_UNSUPPORTED
@@ -86,7 +105,7 @@ class SFTPServerInterface:
requested attributes of the file if it is newly created.
:return: a new `.SFTPHandle` or error code.
"""
- pass
+ return SFTP_OP_UNSUPPORTED
def list_folder(self, path):
"""
@@ -118,7 +137,7 @@ class SFTPServerInterface:
direct translation from the SFTP server path to your local
filesystem.
"""
- pass
+ return SFTP_OP_UNSUPPORTED
def stat(self, path):
"""
@@ -134,7 +153,7 @@ class SFTPServerInterface:
an `.SFTPAttributes` object for the given file, or an SFTP error
code (like ``SFTP_PERMISSION_DENIED``).
"""
- pass
+ return SFTP_OP_UNSUPPORTED
def lstat(self, path):
"""
@@ -152,7 +171,7 @@ class SFTPServerInterface:
an `.SFTPAttributes` object for the given file, or an SFTP error
code (like ``SFTP_PERMISSION_DENIED``).
"""
- pass
+ return SFTP_OP_UNSUPPORTED
def remove(self, path):
"""
@@ -162,7 +181,7 @@ class SFTPServerInterface:
the requested path (relative or absolute) of the file to delete.
:return: an SFTP error code `int` like ``SFTP_OK``.
"""
- pass
+ return SFTP_OP_UNSUPPORTED
def rename(self, oldpath, newpath):
"""
@@ -186,7 +205,7 @@ class SFTPServerInterface:
:param str newpath: the requested new path of the file.
:return: an SFTP error code `int` like ``SFTP_OK``.
"""
- pass
+ return SFTP_OP_UNSUPPORTED
def posix_rename(self, oldpath, newpath):
"""
@@ -200,7 +219,7 @@ class SFTPServerInterface:
:versionadded: 2.2
"""
- pass
+ return SFTP_OP_UNSUPPORTED
def mkdir(self, path, attr):
"""
@@ -217,7 +236,7 @@ class SFTPServerInterface:
:param .SFTPAttributes attr: requested attributes of the new folder.
:return: an SFTP error code `int` like ``SFTP_OK``.
"""
- pass
+ return SFTP_OP_UNSUPPORTED
def rmdir(self, path):
"""
@@ -229,7 +248,7 @@ class SFTPServerInterface:
requested path (relative or absolute) of the folder to remove.
:return: an SFTP error code `int` like ``SFTP_OK``.
"""
- pass
+ return SFTP_OP_UNSUPPORTED
def chattr(self, path, attr):
"""
@@ -244,7 +263,7 @@ class SFTPServerInterface:
object)
:return: an error code `int` like ``SFTP_OK``.
"""
- pass
+ return SFTP_OP_UNSUPPORTED
def canonicalize(self, path):
"""
@@ -260,7 +279,14 @@ class SFTPServerInterface:
The default implementation returns ``os.path.normpath('/' + path)``.
"""
- pass
+ if os.path.isabs(path):
+ out = os.path.normpath(path)
+ else:
+ out = os.path.normpath("/" + path)
+ if sys.platform == "win32":
+ # on windows, normalize backslashes to sftp/posix format
+ out = out.replace("\\", "/")
+ return out
def readlink(self, path):
"""
@@ -273,7 +299,7 @@ class SFTPServerInterface:
the target `str` path of the symbolic link, or an error code like
``SFTP_NO_SUCH_FILE``.
"""
- pass
+ return SFTP_OP_UNSUPPORTED
def symlink(self, target_path, path):
"""
@@ -287,4 +313,4 @@ class SFTPServerInterface:
path (relative or absolute) of the symbolic link to create.
:return: an error code `int` like ``SFTP_OK``.
"""
- pass
+ return SFTP_OP_UNSUPPORTED
diff --git a/paramiko/ssh_exception.py b/paramiko/ssh_exception.py
index f09c6eb0..2b68ebe8 100644
--- a/paramiko/ssh_exception.py
+++ b/paramiko/ssh_exception.py
@@ -1,3 +1,21 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
import socket
@@ -5,6 +23,7 @@ class SSHException(Exception):
"""
Exception raised by failures in SSH2 protocol negotiation or logic errors.
"""
+
pass
@@ -16,6 +35,7 @@ class AuthenticationException(SSHException):
.. versionadded:: 1.6
"""
+
pass
@@ -23,6 +43,7 @@ class PasswordRequiredException(AuthenticationException):
"""
Exception raised when a password is needed to unlock a private key file.
"""
+
pass
@@ -34,22 +55,28 @@ class BadAuthenticationType(AuthenticationException):
.. versionadded:: 1.1
"""
+
allowed_types = []
+ # TODO 4.0: remove explanation kwarg
def __init__(self, explanation, types):
+ # TODO 4.0: remove this supercall unless it's actually required for
+ # pickling (after fixing pickling)
AuthenticationException.__init__(self, explanation, types)
self.explanation = explanation
self.allowed_types = types
def __str__(self):
- return '{}; allowed types: {!r}'.format(self.explanation, self.
- allowed_types)
+ return "{}; allowed types: {!r}".format(
+ self.explanation, self.allowed_types
+ )
class PartialAuthentication(AuthenticationException):
"""
An internal exception thrown in the case of partial authentication.
"""
+
allowed_types = []
def __init__(self, types):
@@ -57,10 +84,12 @@ class PartialAuthentication(AuthenticationException):
self.allowed_types = types
def __str__(self):
- return 'Partial authentication; allowed types: {!r}'.format(self.
- allowed_types)
+ return "Partial authentication; allowed types: {!r}".format(
+ self.allowed_types
+ )
+# TODO 4.0: stop inheriting from SSHException, move to auth.py
class UnableToAuthenticate(AuthenticationException):
pass
@@ -80,7 +109,7 @@ class ChannelException(SSHException):
self.text = text
def __str__(self):
- return 'ChannelException({!r}, {!r})'.format(self.code, self.text)
+ return "ChannelException({!r}, {!r})".format(self.code, self.text)
class BadHostKeyException(SSHException):
@@ -101,10 +130,12 @@ class BadHostKeyException(SSHException):
self.expected_key = expected_key
def __str__(self):
- msg = (
- "Host key for server '{}' does not match: got '{}', expected '{}'")
- return msg.format(self.hostname, self.key.get_base64(), self.
- expected_key.get_base64())
+ msg = "Host key for server '{}' does not match: got '{}', expected '{}'" # noqa
+ return msg.format(
+ self.hostname,
+ self.key.get_base64(),
+ self.expected_key.get_base64(),
+ )
class IncompatiblePeer(SSHException):
@@ -113,6 +144,12 @@ class IncompatiblePeer(SSHException):
.. versionadded:: 2.9
"""
+
+ # TODO 4.0: consider making this annotate w/ 1..N 'missing' algorithms,
+ # either just the first one that would halt kex, or even updating the
+ # Transport logic so we record /all/ that /could/ halt kex.
+ # TODO: update docstrings where this may end up raised so they are more
+ # specific.
pass
@@ -131,7 +168,8 @@ class ProxyCommandFailure(SSHException):
def __str__(self):
return 'ProxyCommand("{}") returned nonzero exit status: {}'.format(
- self.command, self.error)
+ self.command, self.error
+ )
class NoValidConnectionsError(socket.error):
@@ -163,17 +201,19 @@ class NoValidConnectionsError(socket.error):
The errors dict to store, as described by class docstring.
"""
addrs = sorted(errors.keys())
- body = ', '.join([x[0] for x in addrs[:-1]])
+ body = ", ".join([x[0] for x in addrs[:-1]])
tail = addrs[-1][0]
if body:
- msg = 'Unable to connect to port {0} on {1} or {2}'
+ msg = "Unable to connect to port {0} on {1} or {2}"
else:
- msg = 'Unable to connect to port {0} on {2}'
- super().__init__(None, msg.format(addrs[0][1], body, tail))
+ msg = "Unable to connect to port {0} on {2}"
+ super().__init__(
+ None, msg.format(addrs[0][1], body, tail) # stand-in for errno
+ )
self.errors = errors
def __reduce__(self):
- return self.__class__, (self.errors,)
+ return (self.__class__, (self.errors,))
class CouldNotCanonicalize(SSHException):
@@ -182,6 +222,7 @@ class CouldNotCanonicalize(SSHException):
.. versionadded:: 2.7
"""
+
pass
@@ -195,6 +236,7 @@ class ConfigParseError(SSHException):
.. versionadded:: 2.7
"""
+
pass
@@ -204,4 +246,5 @@ class MessageOrderError(SSHException):
.. versionadded:: 3.4
"""
+
pass
diff --git a/paramiko/ssh_gss.py b/paramiko/ssh_gss.py
index 5956a062..ee49c34d 100644
--- a/paramiko/ssh_gss.py
+++ b/paramiko/ssh_gss.py
@@ -1,3 +1,24 @@
+# Copyright (C) 2013-2014 science + computing ag
+# Author: Sebastian Deiss <sebastian.deiss@t-online.de>
+#
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
+
"""
This module provides GSS-API / SSPI authentication as defined in :rfc:`4462`.
@@ -7,31 +28,48 @@ This module provides GSS-API / SSPI authentication as defined in :rfc:`4462`.
.. versionadded:: 1.15
"""
+
import struct
import os
import sys
+
+
+#: A boolean constraint that indicates if GSS-API / SSPI is available.
GSS_AUTH_AVAILABLE = True
+
+
+#: A tuple of the exception types used by the underlying GSSAPI implementation.
GSS_EXCEPTIONS = ()
+
+
+#: :var str _API: Constraint for the used API
_API = None
+
try:
import gssapi
- if hasattr(gssapi, '__title__') and gssapi.__title__ == 'python-gssapi':
- _API = 'MIT'
- GSS_EXCEPTIONS = gssapi.GSSException,
+
+ if hasattr(gssapi, "__title__") and gssapi.__title__ == "python-gssapi":
+ # old, unmaintained python-gssapi package
+ _API = "MIT" # keep this for compatibility
+ GSS_EXCEPTIONS = (gssapi.GSSException,)
else:
- _API = 'PYTHON-GSSAPI-NEW'
- GSS_EXCEPTIONS = (gssapi.exceptions.GeneralError, gssapi.raw.misc.
- GSSError)
+ _API = "PYTHON-GSSAPI-NEW"
+ GSS_EXCEPTIONS = (
+ gssapi.exceptions.GeneralError,
+ gssapi.raw.misc.GSSError,
+ )
except (ImportError, OSError):
try:
import pywintypes
import sspicon
import sspi
- _API = 'SSPI'
- GSS_EXCEPTIONS = pywintypes.error,
+
+ _API = "SSPI"
+ GSS_EXCEPTIONS = (pywintypes.error,)
except ImportError:
GSS_AUTH_AVAILABLE = False
_API = None
+
from paramiko.common import MSG_USERAUTH_REQUEST
from paramiko.ssh_exception import SSHException
from paramiko._version import __version_info__
@@ -59,7 +97,14 @@ def GSSAuth(auth_method, gss_deleg_creds=True):
If there is no supported API available,
``None`` will be returned.
"""
- pass
+ if _API == "MIT":
+ return _SSH_GSSAPI_OLD(auth_method, gss_deleg_creds)
+ elif _API == "PYTHON-GSSAPI-NEW":
+ return _SSH_GSSAPI_NEW(auth_method, gss_deleg_creds)
+ elif _API == "SSPI" and os.name == "nt":
+ return _SSH_SSPI(auth_method, gss_deleg_creds)
+ else:
+ raise ImportError("Unable to import a GSS-API / SSPI module!")
class _SSH_GSSAuth:
@@ -79,14 +124,18 @@ class _SSH_GSSAuth:
self._gss_host = None
self._username = None
self._session_id = None
- self._service = 'ssh-connection'
+ self._service = "ssh-connection"
"""
OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication,
so we also support the krb5 mechanism only.
"""
- self._krb5_mech = '1.2.840.113554.1.2.2'
+ self._krb5_mech = "1.2.840.113554.1.2.2"
+
+ # client mode
self._gss_ctxt = None
self._gss_ctxt_status = False
+
+ # server mode
self._gss_srv_ctxt = None
self._gss_srv_ctxt_status = False
self.cc_file = None
@@ -99,7 +148,8 @@ class _SSH_GSSAuth:
:param str service: The desired SSH service
"""
- pass
+ if service.find("ssh-"):
+ self._service = service
def set_username(self, username):
"""
@@ -108,9 +158,9 @@ class _SSH_GSSAuth:
:param str username: The name of the user who attempts to login
"""
- pass
+ self._username = username
- def ssh_gss_oids(self, mode='client'):
+ def ssh_gss_oids(self, mode="client"):
"""
This method returns a single OID, because we only support the
Kerberos V5 mechanism.
@@ -122,7 +172,15 @@ class _SSH_GSSAuth:
:note: In server mode we just return the OID length and the DER encoded
OID.
"""
- pass
+ from pyasn1.type.univ import ObjectIdentifier
+ from pyasn1.codec.der import encoder
+
+ OIDs = self._make_uint32(1)
+ krb5_OID = encoder.encode(ObjectIdentifier(self._krb5_mech))
+ OID_len = self._make_uint32(len(krb5_OID))
+ if mode == "server":
+ return OID_len + krb5_OID
+ return OIDs + OID_len + krb5_OID
def ssh_check_mech(self, desired_mech):
"""
@@ -131,8 +189,15 @@ class _SSH_GSSAuth:
:param str desired_mech: The desired GSS-API mechanism of the client
:return: ``True`` if the given OID is supported, otherwise C{False}
"""
- pass
+ from pyasn1.codec.der import decoder
+ mech, __ = decoder.decode(desired_mech)
+ if mech.__str__() != self._krb5_mech:
+ return False
+ return True
+
+ # Internals
+ # -------------------------------------------------------------------------
def _make_uint32(self, integer):
"""
Create a 32 bit unsigned integer (The byte sequence of an integer).
@@ -140,7 +205,7 @@ class _SSH_GSSAuth:
:param int integer: The integer value to convert
:return: The byte sequence of an 32 bit integer
"""
- pass
+ return struct.pack("!I", integer)
def _ssh_build_mic(self, session_id, username, service, auth_method):
"""
@@ -159,7 +224,16 @@ class _SSH_GSSAuth:
string authentication-method
(gssapi-with-mic or gssapi-keyex)
"""
- pass
+ mic = self._make_uint32(len(session_id))
+ mic += session_id
+ mic += struct.pack("B", MSG_USERAUTH_REQUEST)
+ mic += self._make_uint32(len(username))
+ mic += username.encode()
+ mic += self._make_uint32(len(service))
+ mic += service.encode()
+ mic += self._make_uint32(len(auth_method))
+ mic += auth_method.encode()
+ return mic
class _SSH_GSSAPI_OLD(_SSH_GSSAuth):
@@ -177,15 +251,24 @@ class _SSH_GSSAPI_OLD(_SSH_GSSAuth):
:param bool gss_deleg_creds: Delegate client credentials or not
"""
_SSH_GSSAuth.__init__(self, auth_method, gss_deleg_creds)
+
if self._gss_deleg_creds:
- self._gss_flags = (gssapi.C_PROT_READY_FLAG, gssapi.
- C_INTEG_FLAG, gssapi.C_MUTUAL_FLAG, gssapi.C_DELEG_FLAG)
+ self._gss_flags = (
+ gssapi.C_PROT_READY_FLAG,
+ gssapi.C_INTEG_FLAG,
+ gssapi.C_MUTUAL_FLAG,
+ gssapi.C_DELEG_FLAG,
+ )
else:
- self._gss_flags = (gssapi.C_PROT_READY_FLAG, gssapi.
- C_INTEG_FLAG, gssapi.C_MUTUAL_FLAG)
+ self._gss_flags = (
+ gssapi.C_PROT_READY_FLAG,
+ gssapi.C_INTEG_FLAG,
+ gssapi.C_MUTUAL_FLAG,
+ )
- def ssh_init_sec_context(self, target, desired_mech=None, username=None,
- recv_token=None):
+ def ssh_init_sec_context(
+ self, target, desired_mech=None, username=None, recv_token=None
+ ):
"""
Initialize a GSS-API context.
@@ -201,7 +284,39 @@ class _SSH_GSSAPI_OLD(_SSH_GSSAuth):
:return: A ``String`` if the GSS-API has returned a token or
``None`` if no token was returned
"""
- pass
+ from pyasn1.codec.der import decoder
+
+ self._username = username
+ self._gss_host = target
+ targ_name = gssapi.Name(
+ "host@" + self._gss_host, gssapi.C_NT_HOSTBASED_SERVICE
+ )
+ ctx = gssapi.Context()
+ ctx.flags = self._gss_flags
+ if desired_mech is None:
+ krb5_mech = gssapi.OID.mech_from_string(self._krb5_mech)
+ else:
+ mech, __ = decoder.decode(desired_mech)
+ if mech.__str__() != self._krb5_mech:
+ raise SSHException("Unsupported mechanism OID.")
+ else:
+ krb5_mech = gssapi.OID.mech_from_string(self._krb5_mech)
+ token = None
+ try:
+ if recv_token is None:
+ self._gss_ctxt = gssapi.InitContext(
+ peer_name=targ_name,
+ mech_type=krb5_mech,
+ req_flags=ctx.flags,
+ )
+ token = self._gss_ctxt.step(token)
+ else:
+ token = self._gss_ctxt.step(recv_token)
+ except gssapi.GSSException:
+ message = "{} Target: {}".format(sys.exc_info()[1], self._gss_host)
+ raise gssapi.GSSException(message)
+ self._gss_ctxt_status = self._gss_ctxt.established
+ return token
def ssh_get_mic(self, session_id, gss_kex=False):
"""
@@ -216,7 +331,19 @@ class _SSH_GSSAPI_OLD(_SSH_GSSAuth):
Returns the MIC token from GSS-API with the SSH session ID as
message.
"""
- pass
+ self._session_id = session_id
+ if not gss_kex:
+ mic_field = self._ssh_build_mic(
+ self._session_id,
+ self._username,
+ self._service,
+ self._auth_method,
+ )
+ mic_token = self._gss_ctxt.get_mic(mic_field)
+ else:
+ # for key exchange with gssapi-keyex
+ mic_token = self._gss_srv_ctxt.get_mic(self._session_id)
+ return mic_token
def ssh_accept_sec_context(self, hostname, recv_token, username=None):
"""
@@ -229,7 +356,14 @@ class _SSH_GSSAPI_OLD(_SSH_GSSAuth):
:return: A ``String`` if the GSS-API has returned a token or ``None``
if no token was returned
"""
- pass
+ # hostname and username are not required for GSSAPI, but for SSPI
+ self._gss_host = hostname
+ self._username = username
+ if self._gss_srv_ctxt is None:
+ self._gss_srv_ctxt = gssapi.AcceptContext()
+ token = self._gss_srv_ctxt.step(recv_token)
+ self._gss_srv_ctxt_status = self._gss_srv_ctxt.established
+ return token
def ssh_check_mic(self, mic_token, session_id, username=None):
"""
@@ -241,7 +375,21 @@ class _SSH_GSSAPI_OLD(_SSH_GSSAuth):
:return: None if the MIC check was successful
:raises: ``gssapi.GSSException`` -- if the MIC check failed
"""
- pass
+ self._session_id = session_id
+ self._username = username
+ if self._username is not None:
+ # server mode
+ mic_field = self._ssh_build_mic(
+ self._session_id,
+ self._username,
+ self._service,
+ self._auth_method,
+ )
+ self._gss_srv_ctxt.verify_mic(mic_field, mic_token)
+ else:
+ # for key exchange with gssapi-keyex
+ # client mode
+ self._gss_ctxt.verify_mic(self._session_id, mic_token)
@property
def credentials_delegated(self):
@@ -250,7 +398,9 @@ class _SSH_GSSAPI_OLD(_SSH_GSSAuth):
:return: ``True`` if credentials are delegated, otherwise ``False``
"""
- pass
+ if self._gss_srv_ctxt.delegated_cred is not None:
+ return True
+ return False
def save_client_creds(self, client_token):
"""
@@ -263,10 +413,11 @@ class _SSH_GSSAPI_OLD(_SSH_GSSAuth):
``NotImplementedError`` -- Credential delegation is currently not
supported in server mode
"""
- pass
+ raise NotImplementedError
if __version_info__ < (2, 5):
+ # provide the old name for strict backward compatibility
_SSH_GSSAPI = _SSH_GSSAPI_OLD
@@ -285,17 +436,24 @@ class _SSH_GSSAPI_NEW(_SSH_GSSAuth):
:param bool gss_deleg_creds: Delegate client credentials or not
"""
_SSH_GSSAuth.__init__(self, auth_method, gss_deleg_creds)
+
if self._gss_deleg_creds:
- self._gss_flags = (gssapi.RequirementFlag.protection_ready,
- gssapi.RequirementFlag.integrity, gssapi.RequirementFlag.
- mutual_authentication, gssapi.RequirementFlag.delegate_to_peer)
+ self._gss_flags = (
+ gssapi.RequirementFlag.protection_ready,
+ gssapi.RequirementFlag.integrity,
+ gssapi.RequirementFlag.mutual_authentication,
+ gssapi.RequirementFlag.delegate_to_peer,
+ )
else:
- self._gss_flags = (gssapi.RequirementFlag.protection_ready,
- gssapi.RequirementFlag.integrity, gssapi.RequirementFlag.
- mutual_authentication)
+ self._gss_flags = (
+ gssapi.RequirementFlag.protection_ready,
+ gssapi.RequirementFlag.integrity,
+ gssapi.RequirementFlag.mutual_authentication,
+ )
- def ssh_init_sec_context(self, target, desired_mech=None, username=None,
- recv_token=None):
+ def ssh_init_sec_context(
+ self, target, desired_mech=None, username=None, recv_token=None
+ ):
"""
Initialize a GSS-API context.
@@ -312,7 +470,32 @@ class _SSH_GSSAPI_NEW(_SSH_GSSAuth):
:return: A ``String`` if the GSS-API has returned a token or ``None``
if no token was returned
"""
- pass
+ from pyasn1.codec.der import decoder
+
+ self._username = username
+ self._gss_host = target
+ targ_name = gssapi.Name(
+ "host@" + self._gss_host,
+ name_type=gssapi.NameType.hostbased_service,
+ )
+ if desired_mech is not None:
+ mech, __ = decoder.decode(desired_mech)
+ if mech.__str__() != self._krb5_mech:
+ raise SSHException("Unsupported mechanism OID.")
+ krb5_mech = gssapi.MechType.kerberos
+ token = None
+ if recv_token is None:
+ self._gss_ctxt = gssapi.SecurityContext(
+ name=targ_name,
+ flags=self._gss_flags,
+ mech=krb5_mech,
+ usage="initiate",
+ )
+ token = self._gss_ctxt.step(token)
+ else:
+ token = self._gss_ctxt.step(recv_token)
+ self._gss_ctxt_status = self._gss_ctxt.complete
+ return token
def ssh_get_mic(self, session_id, gss_kex=False):
"""
@@ -328,7 +511,19 @@ class _SSH_GSSAPI_NEW(_SSH_GSSAuth):
message.
:rtype: str
"""
- pass
+ self._session_id = session_id
+ if not gss_kex:
+ mic_field = self._ssh_build_mic(
+ self._session_id,
+ self._username,
+ self._service,
+ self._auth_method,
+ )
+ mic_token = self._gss_ctxt.get_signature(mic_field)
+ else:
+ # for key exchange with gssapi-keyex
+ mic_token = self._gss_srv_ctxt.get_signature(self._session_id)
+ return mic_token
def ssh_accept_sec_context(self, hostname, recv_token, username=None):
"""
@@ -341,7 +536,14 @@ class _SSH_GSSAPI_NEW(_SSH_GSSAuth):
:return: A ``String`` if the GSS-API has returned a token or ``None``
if no token was returned
"""
- pass
+ # hostname and username are not required for GSSAPI, but for SSPI
+ self._gss_host = hostname
+ self._username = username
+ if self._gss_srv_ctxt is None:
+ self._gss_srv_ctxt = gssapi.SecurityContext(usage="accept")
+ token = self._gss_srv_ctxt.step(recv_token)
+ self._gss_srv_ctxt_status = self._gss_srv_ctxt.complete
+ return token
def ssh_check_mic(self, mic_token, session_id, username=None):
"""
@@ -353,7 +555,21 @@ class _SSH_GSSAPI_NEW(_SSH_GSSAuth):
:return: None if the MIC check was successful
:raises: ``gssapi.exceptions.GSSError`` -- if the MIC check failed
"""
- pass
+ self._session_id = session_id
+ self._username = username
+ if self._username is not None:
+ # server mode
+ mic_field = self._ssh_build_mic(
+ self._session_id,
+ self._username,
+ self._service,
+ self._auth_method,
+ )
+ self._gss_srv_ctxt.verify_signature(mic_field, mic_token)
+ else:
+ # for key exchange with gssapi-keyex
+ # client mode
+ self._gss_ctxt.verify_signature(self._session_id, mic_token)
@property
def credentials_delegated(self):
@@ -363,7 +579,9 @@ class _SSH_GSSAPI_NEW(_SSH_GSSAuth):
:return: ``True`` if credentials are delegated, otherwise ``False``
:rtype: bool
"""
- pass
+ if self._gss_srv_ctxt.delegated_creds is not None:
+ return True
+ return False
def save_client_creds(self, client_token):
"""
@@ -375,7 +593,7 @@ class _SSH_GSSAPI_NEW(_SSH_GSSAuth):
:raises: ``NotImplementedError`` -- Credential delegation is currently
not supported in server mode
"""
- pass
+ raise NotImplementedError
class _SSH_SSPI(_SSH_GSSAuth):
@@ -392,15 +610,21 @@ class _SSH_SSPI(_SSH_GSSAuth):
:param bool gss_deleg_creds: Delegate client credentials or not
"""
_SSH_GSSAuth.__init__(self, auth_method, gss_deleg_creds)
+
if self._gss_deleg_creds:
- self._gss_flags = (sspicon.ISC_REQ_INTEGRITY | sspicon.
- ISC_REQ_MUTUAL_AUTH | sspicon.ISC_REQ_DELEGATE)
+ self._gss_flags = (
+ sspicon.ISC_REQ_INTEGRITY
+ | sspicon.ISC_REQ_MUTUAL_AUTH
+ | sspicon.ISC_REQ_DELEGATE
+ )
else:
- self._gss_flags = (sspicon.ISC_REQ_INTEGRITY | sspicon.
- ISC_REQ_MUTUAL_AUTH)
+ self._gss_flags = (
+ sspicon.ISC_REQ_INTEGRITY | sspicon.ISC_REQ_MUTUAL_AUTH
+ )
- def ssh_init_sec_context(self, target, desired_mech=None, username=None,
- recv_token=None):
+ def ssh_init_sec_context(
+ self, target, desired_mech=None, username=None, recv_token=None
+ ):
"""
Initialize a SSPI context.
@@ -416,7 +640,39 @@ class _SSH_SSPI(_SSH_GSSAuth):
:return: A ``String`` if the SSPI has returned a token or ``None`` if
no token was returned
"""
- pass
+ from pyasn1.codec.der import decoder
+
+ self._username = username
+ self._gss_host = target
+ error = 0
+ targ_name = "host/" + self._gss_host
+ if desired_mech is not None:
+ mech, __ = decoder.decode(desired_mech)
+ if mech.__str__() != self._krb5_mech:
+ raise SSHException("Unsupported mechanism OID.")
+ try:
+ if recv_token is None:
+ self._gss_ctxt = sspi.ClientAuth(
+ "Kerberos", scflags=self._gss_flags, targetspn=targ_name
+ )
+ error, token = self._gss_ctxt.authorize(recv_token)
+ token = token[0].Buffer
+ except pywintypes.error as e:
+ e.strerror += ", Target: {}".format(self._gss_host)
+ raise
+
+ if error == 0:
+ """
+ if the status is GSS_COMPLETE (error = 0) the context is fully
+ established an we can set _gss_ctxt_status to True.
+ """
+ self._gss_ctxt_status = True
+ token = None
+ """
+ You won't get another token if the context is fully established,
+ so i set token to None instead of ""
+ """
+ return token
def ssh_get_mic(self, session_id, gss_kex=False):
"""
@@ -431,7 +687,19 @@ class _SSH_SSPI(_SSH_GSSAuth):
Returns the MIC token from SSPI with the SSH session ID as
message.
"""
- pass
+ self._session_id = session_id
+ if not gss_kex:
+ mic_field = self._ssh_build_mic(
+ self._session_id,
+ self._username,
+ self._service,
+ self._auth_method,
+ )
+ mic_token = self._gss_ctxt.sign(mic_field)
+ else:
+ # for key exchange with gssapi-keyex
+ mic_token = self._gss_srv_ctxt.sign(self._session_id)
+ return mic_token
def ssh_accept_sec_context(self, hostname, username, recv_token):
"""
@@ -444,7 +712,16 @@ class _SSH_SSPI(_SSH_GSSAuth):
:return: A ``String`` if the SSPI has returned a token or ``None`` if
no token was returned
"""
- pass
+ self._gss_host = hostname
+ self._username = username
+ targ_name = "host/" + self._gss_host
+ self._gss_srv_ctxt = sspi.ServerAuth("Kerberos", spn=targ_name)
+ error, token = self._gss_srv_ctxt.authorize(recv_token)
+ token = token[0].Buffer
+ if error == 0:
+ self._gss_srv_ctxt_status = True
+ token = None
+ return token
def ssh_check_mic(self, mic_token, session_id, username=None):
"""
@@ -456,7 +733,25 @@ class _SSH_SSPI(_SSH_GSSAuth):
:return: None if the MIC check was successful
:raises: ``sspi.error`` -- if the MIC check failed
"""
- pass
+ self._session_id = session_id
+ self._username = username
+ if username is not None:
+ # server mode
+ mic_field = self._ssh_build_mic(
+ self._session_id,
+ self._username,
+ self._service,
+ self._auth_method,
+ )
+ # Verifies data and its signature. If verification fails, an
+ # sspi.error will be raised.
+ self._gss_srv_ctxt.verify(mic_field, mic_token)
+ else:
+ # for key exchange with gssapi-keyex
+ # client mode
+ # Verifies data and its signature. If verification fails, an
+ # sspi.error will be raised.
+ self._gss_ctxt.verify(self._session_id, mic_token)
@property
def credentials_delegated(self):
@@ -465,7 +760,9 @@ class _SSH_SSPI(_SSH_GSSAuth):
:return: ``True`` if credentials are delegated, otherwise ``False``
"""
- pass
+ return self._gss_flags & sspicon.ISC_REQ_DELEGATE and (
+ self._gss_srv_ctxt_status or self._gss_flags
+ )
def save_client_creds(self, client_token):
"""
@@ -478,4 +775,4 @@ class _SSH_SSPI(_SSH_GSSAuth):
``NotImplementedError`` -- Credential delegation is currently not
supported in server mode
"""
- pass
+ raise NotImplementedError
diff --git a/paramiko/transport.py b/paramiko/transport.py
index a4f0e92e..ecd8c7bc 100644
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -1,6 +1,26 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Core protocol implementation
"""
+
import os
import socket
import sys
@@ -8,14 +28,67 @@ import threading
import time
import weakref
from hashlib import md5, sha1, sha256, sha512
+
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
+
import paramiko
from paramiko import util
from paramiko.auth_handler import AuthHandler, AuthOnlyHandler
from paramiko.ssh_gss import GSSAuth
from paramiko.channel import Channel
-from paramiko.common import xffffffff, cMSG_CHANNEL_OPEN, cMSG_IGNORE, cMSG_GLOBAL_REQUEST, DEBUG, MSG_KEXINIT, MSG_IGNORE, MSG_DISCONNECT, MSG_DEBUG, ERROR, WARNING, cMSG_UNIMPLEMENTED, INFO, cMSG_KEXINIT, cMSG_NEWKEYS, MSG_NEWKEYS, cMSG_REQUEST_SUCCESS, cMSG_REQUEST_FAILURE, CONNECTION_FAILED_CODE, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_SUCCEEDED, cMSG_CHANNEL_OPEN_FAILURE, cMSG_CHANNEL_OPEN_SUCCESS, MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE, cMSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, MSG_CHANNEL_OPEN, MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE, MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA, MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_REQUEST, MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MIN_WINDOW_SIZE, MIN_PACKET_SIZE, MAX_WINDOW_SIZE, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_PACKET_SIZE, HIGHEST_USERAUTH_MESSAGE_ID, MSG_UNIMPLEMENTED, MSG_NAMES, MSG_EXT_INFO, cMSG_EXT_INFO, byte_ord
+from paramiko.common import (
+ xffffffff,
+ cMSG_CHANNEL_OPEN,
+ cMSG_IGNORE,
+ cMSG_GLOBAL_REQUEST,
+ DEBUG,
+ MSG_KEXINIT,
+ MSG_IGNORE,
+ MSG_DISCONNECT,
+ MSG_DEBUG,
+ ERROR,
+ WARNING,
+ cMSG_UNIMPLEMENTED,
+ INFO,
+ cMSG_KEXINIT,
+ cMSG_NEWKEYS,
+ MSG_NEWKEYS,
+ cMSG_REQUEST_SUCCESS,
+ cMSG_REQUEST_FAILURE,
+ CONNECTION_FAILED_CODE,
+ OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
+ OPEN_SUCCEEDED,
+ cMSG_CHANNEL_OPEN_FAILURE,
+ cMSG_CHANNEL_OPEN_SUCCESS,
+ MSG_GLOBAL_REQUEST,
+ MSG_REQUEST_SUCCESS,
+ MSG_REQUEST_FAILURE,
+ cMSG_SERVICE_REQUEST,
+ MSG_SERVICE_ACCEPT,
+ MSG_CHANNEL_OPEN_SUCCESS,
+ MSG_CHANNEL_OPEN_FAILURE,
+ MSG_CHANNEL_OPEN,
+ MSG_CHANNEL_SUCCESS,
+ MSG_CHANNEL_FAILURE,
+ MSG_CHANNEL_DATA,
+ MSG_CHANNEL_EXTENDED_DATA,
+ MSG_CHANNEL_WINDOW_ADJUST,
+ MSG_CHANNEL_REQUEST,
+ MSG_CHANNEL_EOF,
+ MSG_CHANNEL_CLOSE,
+ MIN_WINDOW_SIZE,
+ MIN_PACKET_SIZE,
+ MAX_WINDOW_SIZE,
+ DEFAULT_WINDOW_SIZE,
+ DEFAULT_MAX_PACKET_SIZE,
+ HIGHEST_USERAUTH_MESSAGE_ID,
+ MSG_UNIMPLEMENTED,
+ MSG_NAMES,
+ MSG_EXT_INFO,
+ cMSG_EXT_INFO,
+ byte_ord,
+)
from paramiko.compress import ZlibCompressor, ZlibDecompressor
from paramiko.dsskey import DSSKey
from paramiko.ed25519key import Ed25519Key
@@ -33,14 +106,46 @@ from paramiko.rsakey import RSAKey
from paramiko.ecdsakey import ECDSAKey
from paramiko.server import ServerInterface
from paramiko.sftp_client import SFTPClient
-from paramiko.ssh_exception import BadAuthenticationType, ChannelException, IncompatiblePeer, MessageOrderError, ProxyCommandFailure, SSHException
-from paramiko.util import ClosingContextManager, clamp_value, b
+from paramiko.ssh_exception import (
+ BadAuthenticationType,
+ ChannelException,
+ IncompatiblePeer,
+ MessageOrderError,
+ ProxyCommandFailure,
+ SSHException,
+)
+from paramiko.util import (
+ ClosingContextManager,
+ clamp_value,
+ b,
+)
+
+
+# TripleDES is moving from `cryptography.hazmat.primitives.ciphers.algorithms`
+# in cryptography>=43.0.0 to `cryptography.hazmat.decrepit.ciphers.algorithms`
+# It will be removed from `cryptography.hazmat.primitives.ciphers.algorithms`
+# in cryptography==48.0.0.
+#
+# Source References:
+# - https://github.com/pyca/cryptography/commit/722a6393e61b3ac
+# - https://github.com/pyca/cryptography/pull/11407/files
try:
from cryptography.hazmat.decrepit.ciphers.algorithms import TripleDES
except ImportError:
from cryptography.hazmat.primitives.ciphers.algorithms import TripleDES
+
+
+# for thread cleanup
_active_threads = []
+
+
+def _join_lingering_threads():
+ for thr in _active_threads:
+ thr.stop_thread()
+
+
import atexit
+
atexit.register(_join_lingering_threads)
@@ -54,85 +159,199 @@ class Transport(threading.Thread, ClosingContextManager):
Instances of this class may be used as context managers.
"""
+
_ENCRYPT = object()
_DECRYPT = object()
- _PROTO_ID = '2.0'
- _CLIENT_ID = 'paramiko_{}'.format(paramiko.__version__)
- _preferred_ciphers = ('aes128-ctr', 'aes192-ctr', 'aes256-ctr',
- 'aes128-cbc', 'aes192-cbc', 'aes256-cbc', '3des-cbc')
- _preferred_macs = ('hmac-sha2-256', 'hmac-sha2-512',
- 'hmac-sha2-256-etm@openssh.com', 'hmac-sha2-512-etm@openssh.com',
- 'hmac-sha1', 'hmac-md5', 'hmac-sha1-96', 'hmac-md5-96')
- _preferred_keys = ('ssh-ed25519', 'ecdsa-sha2-nistp256',
- 'ecdsa-sha2-nistp384', 'ecdsa-sha2-nistp521', 'rsa-sha2-512',
- 'rsa-sha2-256', 'ssh-rsa', 'ssh-dss')
- _preferred_pubkeys = ('ssh-ed25519', 'ecdsa-sha2-nistp256',
- 'ecdsa-sha2-nistp384', 'ecdsa-sha2-nistp521', 'rsa-sha2-512',
- 'rsa-sha2-256', 'ssh-rsa', 'ssh-dss')
- _preferred_kex = ('ecdh-sha2-nistp256', 'ecdh-sha2-nistp384',
- 'ecdh-sha2-nistp521', 'diffie-hellman-group16-sha512',
- 'diffie-hellman-group-exchange-sha256',
- 'diffie-hellman-group14-sha256',
- 'diffie-hellman-group-exchange-sha1', 'diffie-hellman-group14-sha1',
- 'diffie-hellman-group1-sha1')
+
+ _PROTO_ID = "2.0"
+ _CLIENT_ID = "paramiko_{}".format(paramiko.__version__)
+
+ # These tuples of algorithm identifiers are in preference order; do not
+ # reorder without reason!
+ # NOTE: if you need to modify these, we suggest leveraging the
+ # `disabled_algorithms` constructor argument (also available in SSHClient)
+ # instead of monkeypatching or subclassing.
+ _preferred_ciphers = (
+ "aes128-ctr",
+ "aes192-ctr",
+ "aes256-ctr",
+ "aes128-cbc",
+ "aes192-cbc",
+ "aes256-cbc",
+ "3des-cbc",
+ )
+ _preferred_macs = (
+ "hmac-sha2-256",
+ "hmac-sha2-512",
+ "hmac-sha2-256-etm@openssh.com",
+ "hmac-sha2-512-etm@openssh.com",
+ "hmac-sha1",
+ "hmac-md5",
+ "hmac-sha1-96",
+ "hmac-md5-96",
+ )
+ # ~= HostKeyAlgorithms in OpenSSH land
+ _preferred_keys = (
+ "ssh-ed25519",
+ "ecdsa-sha2-nistp256",
+ "ecdsa-sha2-nistp384",
+ "ecdsa-sha2-nistp521",
+ "rsa-sha2-512",
+ "rsa-sha2-256",
+ "ssh-rsa",
+ "ssh-dss",
+ )
+ # ~= PubKeyAcceptedAlgorithms
+ _preferred_pubkeys = (
+ "ssh-ed25519",
+ "ecdsa-sha2-nistp256",
+ "ecdsa-sha2-nistp384",
+ "ecdsa-sha2-nistp521",
+ "rsa-sha2-512",
+ "rsa-sha2-256",
+ "ssh-rsa",
+ "ssh-dss",
+ )
+ _preferred_kex = (
+ "ecdh-sha2-nistp256",
+ "ecdh-sha2-nistp384",
+ "ecdh-sha2-nistp521",
+ "diffie-hellman-group16-sha512",
+ "diffie-hellman-group-exchange-sha256",
+ "diffie-hellman-group14-sha256",
+ "diffie-hellman-group-exchange-sha1",
+ "diffie-hellman-group14-sha1",
+ "diffie-hellman-group1-sha1",
+ )
if KexCurve25519.is_available():
- _preferred_kex = ('curve25519-sha256@libssh.org',) + _preferred_kex
- _preferred_gsskex = ('gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==',
- 'gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==',
- 'gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==')
- _preferred_compression = 'none',
- _cipher_info = {'aes128-ctr': {'class': algorithms.AES, 'mode': modes.
- CTR, 'block-size': 16, 'key-size': 16}, 'aes192-ctr': {'class':
- algorithms.AES, 'mode': modes.CTR, 'block-size': 16, 'key-size': 24
- }, 'aes256-ctr': {'class': algorithms.AES, 'mode': modes.CTR,
- 'block-size': 16, 'key-size': 32}, 'aes128-cbc': {'class':
- algorithms.AES, 'mode': modes.CBC, 'block-size': 16, 'key-size': 16
- }, 'aes192-cbc': {'class': algorithms.AES, 'mode': modes.CBC,
- 'block-size': 16, 'key-size': 24}, 'aes256-cbc': {'class':
- algorithms.AES, 'mode': modes.CBC, 'block-size': 16, 'key-size': 32
- }, '3des-cbc': {'class': TripleDES, 'mode': modes.CBC, 'block-size':
- 8, 'key-size': 24}}
- _mac_info = {'hmac-sha1': {'class': sha1, 'size': 20}, 'hmac-sha1-96':
- {'class': sha1, 'size': 12}, 'hmac-sha2-256': {'class': sha256,
- 'size': 32}, 'hmac-sha2-256-etm@openssh.com': {'class': sha256,
- 'size': 32}, 'hmac-sha2-512': {'class': sha512, 'size': 64},
- 'hmac-sha2-512-etm@openssh.com': {'class': sha512, 'size': 64},
- 'hmac-md5': {'class': md5, 'size': 16}, 'hmac-md5-96': {'class':
- md5, 'size': 12}}
- _key_info = {'ssh-rsa': RSAKey, 'ssh-rsa-cert-v01@openssh.com': RSAKey,
- 'rsa-sha2-256': RSAKey, 'rsa-sha2-256-cert-v01@openssh.com': RSAKey,
- 'rsa-sha2-512': RSAKey, 'rsa-sha2-512-cert-v01@openssh.com': RSAKey,
- 'ssh-dss': DSSKey, 'ssh-dss-cert-v01@openssh.com': DSSKey,
- 'ecdsa-sha2-nistp256': ECDSAKey,
- 'ecdsa-sha2-nistp256-cert-v01@openssh.com': ECDSAKey,
- 'ecdsa-sha2-nistp384': ECDSAKey,
- 'ecdsa-sha2-nistp384-cert-v01@openssh.com': ECDSAKey,
- 'ecdsa-sha2-nistp521': ECDSAKey,
- 'ecdsa-sha2-nistp521-cert-v01@openssh.com': ECDSAKey, 'ssh-ed25519':
- Ed25519Key, 'ssh-ed25519-cert-v01@openssh.com': Ed25519Key}
- _kex_info = {'diffie-hellman-group1-sha1': KexGroup1,
- 'diffie-hellman-group14-sha1': KexGroup14,
- 'diffie-hellman-group-exchange-sha1': KexGex,
- 'diffie-hellman-group-exchange-sha256': KexGexSHA256,
- 'diffie-hellman-group14-sha256': KexGroup14SHA256,
- 'diffie-hellman-group16-sha512': KexGroup16SHA512,
- 'gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==': KexGSSGroup1,
- 'gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==': KexGSSGroup14,
- 'gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==': KexGSSGex,
- 'ecdh-sha2-nistp256': KexNistp256, 'ecdh-sha2-nistp384':
- KexNistp384, 'ecdh-sha2-nistp521': KexNistp521}
+ _preferred_kex = ("curve25519-sha256@libssh.org",) + _preferred_kex
+ _preferred_gsskex = (
+ "gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==",
+ "gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==",
+ "gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==",
+ )
+ _preferred_compression = ("none",)
+
+ _cipher_info = {
+ "aes128-ctr": {
+ "class": algorithms.AES,
+ "mode": modes.CTR,
+ "block-size": 16,
+ "key-size": 16,
+ },
+ "aes192-ctr": {
+ "class": algorithms.AES,
+ "mode": modes.CTR,
+ "block-size": 16,
+ "key-size": 24,
+ },
+ "aes256-ctr": {
+ "class": algorithms.AES,
+ "mode": modes.CTR,
+ "block-size": 16,
+ "key-size": 32,
+ },
+ "aes128-cbc": {
+ "class": algorithms.AES,
+ "mode": modes.CBC,
+ "block-size": 16,
+ "key-size": 16,
+ },
+ "aes192-cbc": {
+ "class": algorithms.AES,
+ "mode": modes.CBC,
+ "block-size": 16,
+ "key-size": 24,
+ },
+ "aes256-cbc": {
+ "class": algorithms.AES,
+ "mode": modes.CBC,
+ "block-size": 16,
+ "key-size": 32,
+ },
+ "3des-cbc": {
+ "class": TripleDES,
+ "mode": modes.CBC,
+ "block-size": 8,
+ "key-size": 24,
+ },
+ }
+
+ _mac_info = {
+ "hmac-sha1": {"class": sha1, "size": 20},
+ "hmac-sha1-96": {"class": sha1, "size": 12},
+ "hmac-sha2-256": {"class": sha256, "size": 32},
+ "hmac-sha2-256-etm@openssh.com": {"class": sha256, "size": 32},
+ "hmac-sha2-512": {"class": sha512, "size": 64},
+ "hmac-sha2-512-etm@openssh.com": {"class": sha512, "size": 64},
+ "hmac-md5": {"class": md5, "size": 16},
+ "hmac-md5-96": {"class": md5, "size": 12},
+ }
+
+ _key_info = {
+ # TODO: at some point we will want to drop this as it's no longer
+ # considered secure due to using SHA-1 for signatures. OpenSSH 8.8 no
+ # longer supports it. Question becomes at what point do we want to
+ # prevent users with older setups from using this?
+ "ssh-rsa": RSAKey,
+ "ssh-rsa-cert-v01@openssh.com": RSAKey,
+ "rsa-sha2-256": RSAKey,
+ "rsa-sha2-256-cert-v01@openssh.com": RSAKey,
+ "rsa-sha2-512": RSAKey,
+ "rsa-sha2-512-cert-v01@openssh.com": RSAKey,
+ "ssh-dss": DSSKey,
+ "ssh-dss-cert-v01@openssh.com": DSSKey,
+ "ecdsa-sha2-nistp256": ECDSAKey,
+ "ecdsa-sha2-nistp256-cert-v01@openssh.com": ECDSAKey,
+ "ecdsa-sha2-nistp384": ECDSAKey,
+ "ecdsa-sha2-nistp384-cert-v01@openssh.com": ECDSAKey,
+ "ecdsa-sha2-nistp521": ECDSAKey,
+ "ecdsa-sha2-nistp521-cert-v01@openssh.com": ECDSAKey,
+ "ssh-ed25519": Ed25519Key,
+ "ssh-ed25519-cert-v01@openssh.com": Ed25519Key,
+ }
+
+ _kex_info = {
+ "diffie-hellman-group1-sha1": KexGroup1,
+ "diffie-hellman-group14-sha1": KexGroup14,
+ "diffie-hellman-group-exchange-sha1": KexGex,
+ "diffie-hellman-group-exchange-sha256": KexGexSHA256,
+ "diffie-hellman-group14-sha256": KexGroup14SHA256,
+ "diffie-hellman-group16-sha512": KexGroup16SHA512,
+ "gss-group1-sha1-toWM5Slw5Ew8Mqkay+al2g==": KexGSSGroup1,
+ "gss-group14-sha1-toWM5Slw5Ew8Mqkay+al2g==": KexGSSGroup14,
+ "gss-gex-sha1-toWM5Slw5Ew8Mqkay+al2g==": KexGSSGex,
+ "ecdh-sha2-nistp256": KexNistp256,
+ "ecdh-sha2-nistp384": KexNistp384,
+ "ecdh-sha2-nistp521": KexNistp521,
+ }
if KexCurve25519.is_available():
- _kex_info['curve25519-sha256@libssh.org'] = KexCurve25519
- _compression_info = {'zlib@openssh.com': (ZlibCompressor,
- ZlibDecompressor), 'zlib': (ZlibCompressor, ZlibDecompressor),
- 'none': (None, None)}
+ _kex_info["curve25519-sha256@libssh.org"] = KexCurve25519
+
+ _compression_info = {
+ # zlib@openssh.com is just zlib, but only turned on after a successful
+ # authentication. openssh servers may only offer this type because
+ # they've had troubles with security holes in zlib in the past.
+ "zlib@openssh.com": (ZlibCompressor, ZlibDecompressor),
+ "zlib": (ZlibCompressor, ZlibDecompressor),
+ "none": (None, None),
+ }
+
_modulus_pack = None
_active_check_timeout = 0.1
- def __init__(self, sock, default_window_size=DEFAULT_WINDOW_SIZE,
- default_max_packet_size=DEFAULT_MAX_PACKET_SIZE, gss_kex=False,
- gss_deleg_creds=True, disabled_algorithms=None, server_sig_algs=
- True, strict_kex=True, packetizer_class=None):
+ def __init__(
+ self,
+ sock,
+ default_window_size=DEFAULT_WINDOW_SIZE,
+ default_max_packet_size=DEFAULT_MAX_PACKET_SIZE,
+ gss_kex=False,
+ gss_deleg_creds=True,
+ disabled_algorithms=None,
+ server_sig_algs=True,
+ strict_kex=True,
+ packetizer_class=None,
+ ):
"""
Create a new SSH session over an existing socket, or socket-like
object. This only creates the `.Transport` object; it doesn't begin
@@ -225,22 +444,29 @@ class Transport(threading.Thread, ClosingContextManager):
self.server_extensions = {}
self.advertise_strict_kex = strict_kex
self.agreed_on_strict_kex = False
+
+ # TODO: these two overrides on sock's type should go away sometime, too
+ # many ways to do it!
if isinstance(sock, str):
- hl = sock.split(':', 1)
+ # convert "host:port" into (host, port)
+ hl = sock.split(":", 1)
self.hostname = hl[0]
if len(hl) == 1:
- sock = hl[0], 22
+ sock = (hl[0], 22)
else:
- sock = hl[0], int(hl[1])
+ sock = (hl[0], int(hl[1]))
if type(sock) is tuple:
+ # connect to the given (host, port)
hostname, port = sock
self.hostname = hostname
- reason = 'No suitable address family'
- addrinfos = socket.getaddrinfo(hostname, port, socket.AF_UNSPEC,
- socket.SOCK_STREAM)
+ reason = "No suitable address family"
+ addrinfos = socket.getaddrinfo(
+ hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM
+ )
for family, socktype, proto, canonname, sockaddr in addrinfos:
if socktype == socket.SOCK_STREAM:
af = family
+ # addr = sockaddr
sock = socket.socket(af, socket.SOCK_STREAM)
try:
sock.connect((hostname, port))
@@ -249,97 +475,171 @@ class Transport(threading.Thread, ClosingContextManager):
else:
break
else:
- raise SSHException('Unable to connect to {}: {}'.format(
- hostname, reason))
+ raise SSHException(
+ "Unable to connect to {}: {}".format(hostname, reason)
+ )
+ # okay, normal socket-ish flow here...
threading.Thread.__init__(self)
self.daemon = True
self.sock = sock
+ # we set the timeout so we can check self.active periodically to
+ # see if we should bail. socket.timeout exception is never propagated.
self.sock.settimeout(self._active_check_timeout)
+
+ # negotiated crypto parameters
self.packetizer = (packetizer_class or Packetizer)(sock)
- self.local_version = 'SSH-' + self._PROTO_ID + '-' + self._CLIENT_ID
- self.remote_version = ''
- self.local_cipher = self.remote_cipher = ''
+ self.local_version = "SSH-" + self._PROTO_ID + "-" + self._CLIENT_ID
+ self.remote_version = ""
+ self.local_cipher = self.remote_cipher = ""
self.local_kex_init = self.remote_kex_init = None
self.local_mac = self.remote_mac = None
self.local_compression = self.remote_compression = None
self.session_id = None
self.host_key_type = None
self.host_key = None
+
+ # GSS-API / SSPI Key Exchange
self.use_gss_kex = gss_kex
+ # This will be set to True if GSS-API Key Exchange was performed
self.gss_kex_used = False
self.kexgss_ctxt = None
self.gss_host = None
if self.use_gss_kex:
- self.kexgss_ctxt = GSSAuth('gssapi-keyex', gss_deleg_creds)
+ self.kexgss_ctxt = GSSAuth("gssapi-keyex", gss_deleg_creds)
self._preferred_kex = self._preferred_gsskex + self._preferred_kex
+
+ # state used during negotiation
self.kex_engine = None
self.H = None
self.K = None
+
self.initial_kex_done = False
self.in_kex = False
self.authenticated = False
self._expected_packet = tuple()
+ # synchronization (always higher level than write_lock)
self.lock = threading.Lock()
+
+ # tracking open channels
self._channels = ChannelMap()
- self.channel_events = {}
- self.channels_seen = {}
+ self.channel_events = {} # (id -> Event)
+ self.channels_seen = {} # (id -> True)
self._channel_counter = 0
self.default_max_packet_size = default_max_packet_size
self.default_window_size = default_window_size
self._forward_agent_handler = None
self._x11_handler = None
self._tcp_handler = None
+
self.saved_exception = None
self.clear_to_send = threading.Event()
self.clear_to_send_lock = threading.Lock()
self.clear_to_send_timeout = 30.0
- self.log_name = 'paramiko.transport'
+ self.log_name = "paramiko.transport"
self.logger = util.get_logger(self.log_name)
self.packetizer.set_log(self.logger)
self.auth_handler = None
+ # response Message from an arbitrary global request
self.global_response = None
+ # user-defined event callbacks
self.completion_event = None
+ # how long (seconds) to wait for the SSH banner
self.banner_timeout = 15
+ # how long (seconds) to wait for the handshake to finish after SSH
+ # banner sent.
self.handshake_timeout = 15
+ # how long (seconds) to wait for the auth response.
self.auth_timeout = 30
+ # how long (seconds) to wait for opening a channel
self.channel_timeout = 60 * 60
self.disabled_algorithms = disabled_algorithms or {}
self.server_sig_algs = server_sig_algs
+
+ # server mode:
self.server_mode = False
self.server_object = None
self.server_key_dict = {}
self.server_accepts = []
self.server_accept_cv = threading.Condition(self.lock)
self.subsystem_table = {}
- self._handler_table = {MSG_EXT_INFO: self._parse_ext_info,
- MSG_NEWKEYS: self._parse_newkeys, MSG_GLOBAL_REQUEST: self.
- _parse_global_request, MSG_REQUEST_SUCCESS: self.
- _parse_request_success, MSG_REQUEST_FAILURE: self.
- _parse_request_failure, MSG_CHANNEL_OPEN_SUCCESS: self.
- _parse_channel_open_success, MSG_CHANNEL_OPEN_FAILURE: self.
- _parse_channel_open_failure, MSG_CHANNEL_OPEN: self.
- _parse_channel_open, MSG_KEXINIT: self._negotiate_keys}
+
+ # Handler table, now set at init time for easier per-instance
+ # manipulation and subclass twiddling.
+ self._handler_table = {
+ MSG_EXT_INFO: self._parse_ext_info,
+ MSG_NEWKEYS: self._parse_newkeys,
+ MSG_GLOBAL_REQUEST: self._parse_global_request,
+ MSG_REQUEST_SUCCESS: self._parse_request_success,
+ MSG_REQUEST_FAILURE: self._parse_request_failure,
+ MSG_CHANNEL_OPEN_SUCCESS: self._parse_channel_open_success,
+ MSG_CHANNEL_OPEN_FAILURE: self._parse_channel_open_failure,
+ MSG_CHANNEL_OPEN: self._parse_channel_open,
+ MSG_KEXINIT: self._negotiate_keys,
+ }
+
+ def _filter_algorithm(self, type_):
+ default = getattr(self, "_preferred_{}".format(type_))
+ return tuple(
+ x
+ for x in default
+ if x not in self.disabled_algorithms.get(type_, [])
+ )
+
+ @property
+ def preferred_ciphers(self):
+ return self._filter_algorithm("ciphers")
+
+ @property
+ def preferred_macs(self):
+ return self._filter_algorithm("macs")
+
+ @property
+ def preferred_keys(self):
+ # Interleave cert variants here; resistant to various background
+ # overwriting of _preferred_keys, and necessary as hostkeys can't use
+ # the logic pubkey auth does re: injecting/checking for certs at
+ # runtime
+ filtered = self._filter_algorithm("keys")
+ return tuple(
+ filtered
+ + tuple("{}-cert-v01@openssh.com".format(x) for x in filtered)
+ )
+
+ @property
+ def preferred_pubkeys(self):
+ return self._filter_algorithm("pubkeys")
+
+ @property
+ def preferred_kex(self):
+ return self._filter_algorithm("kex")
+
+ @property
+ def preferred_compression(self):
+ return self._filter_algorithm("compression")
def __repr__(self):
"""
Returns a string representation of this object, for debugging.
"""
id_ = hex(id(self) & xffffffff)
- out = '<paramiko.Transport at {}'.format(id_)
+ out = "<paramiko.Transport at {}".format(id_)
if not self.active:
- out += ' (unconnected)'
+ out += " (unconnected)"
else:
- if self.local_cipher != '':
- out += ' (cipher {}, {:d} bits)'.format(self.local_cipher,
- self._cipher_info[self.local_cipher]['key-size'] * 8)
+ if self.local_cipher != "":
+ out += " (cipher {}, {:d} bits)".format(
+ self.local_cipher,
+ self._cipher_info[self.local_cipher]["key-size"] * 8,
+ )
if self.is_authenticated():
- out += ' (active; {} open channel(s))'.format(len(self.
- _channels))
+ out += " (active; {} open channel(s))".format(
+ len(self._channels)
+ )
elif self.initial_kex_done:
- out += ' (connected; awaiting auth)'
+ out += " (connected; awaiting auth)"
else:
- out += ' (connecting)'
- out += '>'
+ out += " (connecting)"
+ out += ">"
return out
def atfork(self):
@@ -352,7 +652,8 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionadded:: 1.5.3
"""
- pass
+ self.sock.close()
+ self.close()
def get_security_options(self):
"""
@@ -361,7 +662,7 @@ class Transport(threading.Thread, ClosingContextManager):
digest/hash operations, public keys, and key exchanges) and the order
of preference for them.
"""
- pass
+ return SecurityOptions(self)
def set_gss_host(self, gss_host, trust_dns=True, gssapi_requested=True):
"""
@@ -384,7 +685,18 @@ class Transport(threading.Thread, ClosingContextManager):
(Defaults to True due to backwards compatibility.)
:returns: ``None``.
"""
- pass
+ # No GSSAPI in play == nothing to do
+ if not gssapi_requested:
+ return
+ # Obtain the correct host first - did user request a GSS-specific name
+ # to use that is distinct from the actual SSH target hostname?
+ if gss_host is None:
+ gss_host = self.hostname
+ # Finally, canonicalize via DNS if DNS is trusted.
+ if trust_dns and gss_host is not None:
+ gss_host = socket.getfqdn(gss_host)
+ # And set attribute for reference later.
+ self.gss_host = gss_host
def start_client(self, event=None, timeout=None):
"""
@@ -421,7 +733,28 @@ class Transport(threading.Thread, ClosingContextManager):
`.SSHException` -- if negotiation fails (and no ``event`` was
passed in)
"""
- pass
+ self.active = True
+ if event is not None:
+ # async, return immediately and let the app poll for completion
+ self.completion_event = event
+ self.start()
+ return
+
+ # synchronous, wait for a result
+ self.completion_event = event = threading.Event()
+ self.start()
+ max_time = time.time() + timeout if timeout is not None else None
+ while True:
+ event.wait(0.1)
+ if not self.active:
+ e = self.get_exception()
+ if e is not None:
+ raise e
+ raise SSHException("Negotiation failed.")
+ if event.is_set() or (
+ timeout is not None and time.time() >= max_time
+ ):
+ break
def start_server(self, event=None, server=None):
"""
@@ -465,7 +798,29 @@ class Transport(threading.Thread, ClosingContextManager):
`.SSHException` -- if negotiation fails (and no ``event`` was
passed in)
"""
- pass
+ if server is None:
+ server = ServerInterface()
+ self.server_mode = True
+ self.server_object = server
+ self.active = True
+ if event is not None:
+ # async, return immediately and let the app poll for completion
+ self.completion_event = event
+ self.start()
+ return
+
+ # synchronous, wait for a result
+ self.completion_event = event = threading.Event()
+ self.start()
+ while True:
+ event.wait(0.1)
+ if not self.active:
+ e = self.get_exception()
+ if e is not None:
+ raise e
+ raise SSHException("Negotiation failed.")
+ if event.is_set():
+ break
def add_server_key(self, key):
"""
@@ -479,7 +834,13 @@ class Transport(threading.Thread, ClosingContextManager):
:param .PKey key:
the host key to add, usually an `.RSAKey` or `.DSSKey`.
"""
- pass
+ self.server_key_dict[key.get_name()] = key
+ # Handle SHA-2 extensions for RSA by ensuring that lookups into
+ # self.server_key_dict will yield this key for any of the algorithm
+ # names.
+ if isinstance(key, RSAKey):
+ self.server_key_dict["rsa-sha2-256"] = key
+ self.server_key_dict["rsa-sha2-512"] = key
def get_server_key(self):
"""
@@ -496,7 +857,11 @@ class Transport(threading.Thread, ClosingContextManager):
host key (`.PKey`) of the type negotiated by the client, or
``None``.
"""
- pass
+ try:
+ return self.server_key_dict[self.host_key_type]
+ except KeyError:
+ pass
+ return None
@staticmethod
def load_server_moduli(filename=None):
@@ -524,13 +889,31 @@ class Transport(threading.Thread, ClosingContextManager):
.. note:: This has no effect when used in client mode.
"""
- pass
+ Transport._modulus_pack = ModulusPack()
+ # places to look for the openssh "moduli" file
+ file_list = ["/etc/ssh/moduli", "/usr/local/etc/moduli"]
+ if filename is not None:
+ file_list.insert(0, filename)
+ for fn in file_list:
+ try:
+ Transport._modulus_pack.read_file(fn)
+ return True
+ except IOError:
+ pass
+ # none succeeded
+ Transport._modulus_pack = None
+ return False
def close(self):
"""
Close this session, and any open channels that are tied to it.
"""
- pass
+ if not self.active:
+ return
+ self.stop_thread()
+ for chan in list(self._channels.values()):
+ chan._unlink()
+ self.sock.close()
def get_remote_server_key(self):
"""
@@ -545,7 +928,9 @@ class Transport(threading.Thread, ClosingContextManager):
:return: public key (`.PKey`) of the remote server
"""
- pass
+ if (not self.active) or (not self.initial_kex_done):
+ raise SSHException("No existing session")
+ return self.host_key
def is_active(self):
"""
@@ -555,10 +940,11 @@ class Transport(threading.Thread, ClosingContextManager):
True if the session is still active (open); False if the session is
closed
"""
- pass
+ return self.active
- def open_session(self, window_size=None, max_packet_size=None, timeout=None
- ):
+ def open_session(
+ self, window_size=None, max_packet_size=None, timeout=None
+ ):
"""
Request a new channel to the server, of type ``"session"``. This is
just an alias for calling `open_channel` with an argument of
@@ -584,7 +970,12 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionchanged:: 1.15
Added the ``window_size`` and ``max_packet_size`` arguments.
"""
- pass
+ return self.open_channel(
+ "session",
+ window_size=window_size,
+ max_packet_size=max_packet_size,
+ timeout=timeout,
+ )
def open_x11_channel(self, src_addr=None):
"""
@@ -600,7 +991,7 @@ class Transport(threading.Thread, ClosingContextManager):
`.SSHException` -- if the request is rejected or the session ends
prematurely
"""
- pass
+ return self.open_channel("x11", src_addr=src_addr)
def open_forward_agent_channel(self):
"""
@@ -614,7 +1005,7 @@ class Transport(threading.Thread, ClosingContextManager):
:raises: `.SSHException` --
if the request is rejected or the session ends prematurely
"""
- pass
+ return self.open_channel("auth-agent@openssh.com")
def open_forwarded_tcpip_channel(self, src_addr, dest_addr):
"""
@@ -626,10 +1017,17 @@ class Transport(threading.Thread, ClosingContextManager):
:param src_addr: originator's address
:param dest_addr: local (server) connected address
"""
- pass
+ return self.open_channel("forwarded-tcpip", dest_addr, src_addr)
- def open_channel(self, kind, dest_addr=None, src_addr=None, window_size
- =None, max_packet_size=None, timeout=None):
+ def open_channel(
+ self,
+ kind,
+ dest_addr=None,
+ src_addr=None,
+ window_size=None,
+ max_packet_size=None,
+ timeout=None,
+ ):
"""
Request a new channel to the server. `Channels <.Channel>` are
socket-like objects used for the actual transfer of data across the
@@ -665,7 +1063,56 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionchanged:: 1.15
Added the ``window_size`` and ``max_packet_size`` arguments.
"""
- pass
+ if not self.active:
+ raise SSHException("SSH session not active")
+ timeout = self.channel_timeout if timeout is None else timeout
+ self.lock.acquire()
+ try:
+ window_size = self._sanitize_window_size(window_size)
+ max_packet_size = self._sanitize_packet_size(max_packet_size)
+ chanid = self._next_channel()
+ m = Message()
+ m.add_byte(cMSG_CHANNEL_OPEN)
+ m.add_string(kind)
+ m.add_int(chanid)
+ m.add_int(window_size)
+ m.add_int(max_packet_size)
+ if (kind == "forwarded-tcpip") or (kind == "direct-tcpip"):
+ m.add_string(dest_addr[0])
+ m.add_int(dest_addr[1])
+ m.add_string(src_addr[0])
+ m.add_int(src_addr[1])
+ elif kind == "x11":
+ m.add_string(src_addr[0])
+ m.add_int(src_addr[1])
+ chan = Channel(chanid)
+ self._channels.put(chanid, chan)
+ self.channel_events[chanid] = event = threading.Event()
+ self.channels_seen[chanid] = True
+ chan._set_transport(self)
+ chan._set_window(window_size, max_packet_size)
+ finally:
+ self.lock.release()
+ self._send_user_message(m)
+ start_ts = time.time()
+ while True:
+ event.wait(0.1)
+ if not self.active:
+ e = self.get_exception()
+ if e is None:
+ e = SSHException("Unable to open channel.")
+ raise e
+ if event.is_set():
+ break
+ elif start_ts + timeout < time.time():
+ raise SSHException("Timeout opening channel.")
+ chan = self._channels.get(chanid)
+ if chan is not None:
+ return chan
+ e = self.get_exception()
+ if e is None:
+ e = SSHException("Unable to open channel.")
+ raise e
def request_port_forward(self, address, port, handler=None):
"""
@@ -700,7 +1147,26 @@ class Transport(threading.Thread, ClosingContextManager):
:raises:
`.SSHException` -- if the server refused the TCP forward request
"""
- pass
+ if not self.active:
+ raise SSHException("SSH session not active")
+ port = int(port)
+ response = self.global_request(
+ "tcpip-forward", (address, port), wait=True
+ )
+ if response is None:
+ raise SSHException("TCP forwarding request denied")
+ if port == 0:
+ port = response.get_int()
+ if handler is None:
+
+ def default_handler(channel, src_addr, dest_addr_port):
+ # src_addr, src_port = src_addr_port
+ # dest_addr, dest_port = dest_addr_port
+ self._queue_incoming_channel(channel)
+
+ handler = default_handler
+ self._tcp_handler = handler
+ return port
def cancel_port_forward(self, address, port):
"""
@@ -711,7 +1177,10 @@ class Transport(threading.Thread, ClosingContextManager):
:param str address: the address to stop forwarding
:param int port: the port to stop forwarding
"""
- pass
+ if not self.active:
+ return
+ self._tcp_handler = None
+ self.global_request("cancel-tcpip-forward", (address, port), wait=True)
def open_sftp_client(self):
"""
@@ -723,7 +1192,7 @@ class Transport(threading.Thread, ClosingContextManager):
a new `.SFTPClient` referring to an sftp session (channel) across
this transport
"""
- pass
+ return SFTPClient.from_transport(self)
def send_ignore(self, byte_count=None):
"""
@@ -736,7 +1205,12 @@ class Transport(threading.Thread, ClosingContextManager):
the number of random bytes to send in the payload of the ignored
packet -- defaults to a random number from 10 to 41.
"""
- pass
+ m = Message()
+ m.add_byte(cMSG_IGNORE)
+ if byte_count is None:
+ byte_count = (byte_ord(os.urandom(1)) % 32) + 10
+ m.add_bytes(os.urandom(byte_count))
+ self._send_user_message(m)
def renegotiate_keys(self):
"""
@@ -751,7 +1225,18 @@ class Transport(threading.Thread, ClosingContextManager):
`.SSHException` -- if the key renegotiation failed (which causes
the session to end)
"""
- pass
+ self.completion_event = threading.Event()
+ self._send_kex_init()
+ while True:
+ self.completion_event.wait(0.1)
+ if not self.active:
+ e = self.get_exception()
+ if e is not None:
+ raise e
+ raise SSHException("Negotiation failed.")
+ if self.completion_event.is_set():
+ break
+ return
def set_keepalive(self, interval):
"""
@@ -764,7 +1249,11 @@ class Transport(threading.Thread, ClosingContextManager):
seconds to wait before sending a keepalive packet (or
0 to disable keepalives).
"""
- pass
+
+ def _request(x=weakref.proxy(self)):
+ return x.global_request("keepalive@lag.net", wait=False)
+
+ self.packetizer.set_keepalive(interval, _request)
def global_request(self, kind, data=None, wait=True):
"""
@@ -783,7 +1272,25 @@ class Transport(threading.Thread, ClosingContextManager):
successful (or an empty `.Message` if ``wait`` was ``False``);
``None`` if the request was denied.
"""
- pass
+ if wait:
+ self.completion_event = threading.Event()
+ m = Message()
+ m.add_byte(cMSG_GLOBAL_REQUEST)
+ m.add_string(kind)
+ m.add_boolean(wait)
+ if data is not None:
+ m.add(*data)
+ self._log(DEBUG, 'Sending global request "{}"'.format(kind))
+ self._send_user_message(m)
+ if not wait:
+ return None
+ while True:
+ self.completion_event.wait(0.1)
+ if not self.active:
+ return None
+ if self.completion_event.is_set():
+ break
+ return self.global_response
def accept(self, timeout=None):
"""
@@ -795,11 +1302,33 @@ class Transport(threading.Thread, ClosingContextManager):
seconds to wait for a channel, or ``None`` to wait forever
:return: a new `.Channel` opened by the client
"""
- pass
-
- def connect(self, hostkey=None, username='', password=None, pkey=None,
- gss_host=None, gss_auth=False, gss_kex=False, gss_deleg_creds=True,
- gss_trust_dns=True):
+ self.lock.acquire()
+ try:
+ if len(self.server_accepts) > 0:
+ chan = self.server_accepts.pop(0)
+ else:
+ self.server_accept_cv.wait(timeout)
+ if len(self.server_accepts) > 0:
+ chan = self.server_accepts.pop(0)
+ else:
+ # timeout
+ chan = None
+ finally:
+ self.lock.release()
+ return chan
+
+ def connect(
+ self,
+ hostkey=None,
+ username="",
+ password=None,
+ pkey=None,
+ gss_host=None,
+ gss_auth=False,
+ gss_kex=False,
+ gss_deleg_creds=True,
+ gss_trust_dns=True,
+ ):
"""
Negotiate an SSH2 session, and optionally verify the server's host key
and authenticate using a password or private key. This is a shortcut
@@ -848,7 +1377,73 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionchanged:: 2.3
Added the ``gss_trust_dns`` argument.
"""
- pass
+ if hostkey is not None:
+ # TODO: a more robust implementation would be to ask each key class
+ # for its nameS plural, and just use that.
+ # TODO: that could be used in a bunch of other spots too
+ if isinstance(hostkey, RSAKey):
+ self._preferred_keys = [
+ "rsa-sha2-512",
+ "rsa-sha2-256",
+ "ssh-rsa",
+ ]
+ else:
+ self._preferred_keys = [hostkey.get_name()]
+
+ self.set_gss_host(
+ gss_host=gss_host,
+ trust_dns=gss_trust_dns,
+ gssapi_requested=gss_kex or gss_auth,
+ )
+
+ self.start_client()
+
+ # check host key if we were given one
+ # If GSS-API Key Exchange was performed, we are not required to check
+ # the host key.
+ if (hostkey is not None) and not gss_kex:
+ key = self.get_remote_server_key()
+ if (
+ key.get_name() != hostkey.get_name()
+ or key.asbytes() != hostkey.asbytes()
+ ):
+ self._log(DEBUG, "Bad host key from server")
+ self._log(
+ DEBUG,
+ "Expected: {}: {}".format(
+ hostkey.get_name(), repr(hostkey.asbytes())
+ ),
+ )
+ self._log(
+ DEBUG,
+ "Got : {}: {}".format(
+ key.get_name(), repr(key.asbytes())
+ ),
+ )
+ raise SSHException("Bad host key from server")
+ self._log(
+ DEBUG, "Host key verified ({})".format(hostkey.get_name())
+ )
+
+ if (pkey is not None) or (password is not None) or gss_auth or gss_kex:
+ if gss_auth:
+ self._log(
+ DEBUG, "Attempting GSS-API auth... (gssapi-with-mic)"
+ ) # noqa
+ self.auth_gssapi_with_mic(
+ username, self.gss_host, gss_deleg_creds
+ )
+ elif gss_kex:
+ self._log(DEBUG, "Attempting GSS-API auth... (gssapi-keyex)")
+ self.auth_gssapi_keyex(username)
+ elif pkey is not None:
+ self._log(DEBUG, "Attempting public-key auth...")
+ self.auth_publickey(username, pkey)
+ else:
+ self._log(DEBUG, "Attempting password auth...")
+ self.auth_password(username, password)
+
+ return
def get_exception(self):
"""
@@ -862,7 +1457,13 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionadded:: 1.1
"""
- pass
+ self.lock.acquire()
+ try:
+ e = self.saved_exception
+ self.saved_exception = None
+ return e
+ finally:
+ self.lock.release()
def set_subsystem_handler(self, name, handler, *args, **kwargs):
"""
@@ -878,7 +1479,11 @@ class Transport(threading.Thread, ClosingContextManager):
:param handler:
subclass of `.SubsystemHandler` that handles this subsystem.
"""
- pass
+ try:
+ self.lock.acquire()
+ self.subsystem_table[name] = (handler, args, kwargs)
+ finally:
+ self.lock.release()
def is_authenticated(self):
"""
@@ -889,7 +1494,11 @@ class Transport(threading.Thread, ClosingContextManager):
successfully; False if authentication failed and/or the session is
closed.
"""
- pass
+ return (
+ self.active
+ and self.auth_handler is not None
+ and self.auth_handler.is_authenticated()
+ )
def get_username(self):
"""
@@ -899,7 +1508,9 @@ class Transport(threading.Thread, ClosingContextManager):
:return: username that was authenticated (a `str`), or ``None``.
"""
- pass
+ if not self.active or (self.auth_handler is None):
+ return None
+ return self.auth_handler.get_username()
def get_banner(self):
"""
@@ -910,7 +1521,9 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionadded:: 1.13
"""
- pass
+ if not self.active or (self.auth_handler is None):
+ return None
+ return self.auth_handler.banner
def auth_none(self, username):
"""
@@ -933,7 +1546,12 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionadded:: 1.5
"""
- pass
+ if (not self.active) or (not self.initial_kex_done):
+ raise SSHException("No existing session")
+ my_event = threading.Event()
+ self.auth_handler = AuthHandler(self)
+ self.auth_handler.auth_none(username, my_event)
+ return self.auth_handler.wait_for_response(my_event)
def auth_password(self, username, password, event=None, fallback=True):
"""
@@ -982,7 +1600,43 @@ class Transport(threading.Thread, ClosingContextManager):
event was passed in)
:raises: `.SSHException` -- if there was a network error
"""
- pass
+ if (not self.active) or (not self.initial_kex_done):
+ # we should never try to send the password unless we're on a secure
+ # link
+ raise SSHException("No existing session")
+ if event is None:
+ my_event = threading.Event()
+ else:
+ my_event = event
+ self.auth_handler = AuthHandler(self)
+ self.auth_handler.auth_password(username, password, my_event)
+ if event is not None:
+ # caller wants to wait for event themselves
+ return []
+ try:
+ return self.auth_handler.wait_for_response(my_event)
+ except BadAuthenticationType as e:
+ # if password auth isn't allowed, but keyboard-interactive *is*,
+ # try to fudge it
+ if not fallback or ("keyboard-interactive" not in e.allowed_types):
+ raise
+ try:
+
+ def handler(title, instructions, fields):
+ if len(fields) > 1:
+ raise SSHException("Fallback authentication failed.")
+ if len(fields) == 0:
+ # for some reason, at least on os x, a 2nd request will
+ # be made with zero fields requested. maybe it's just
+ # to try to fake out automated scripting of the exact
+ # type we're doing here. *shrug* :)
+ return []
+ return [password]
+
+ return self.auth_interactive(username, handler)
+ except SSHException:
+ # attempt failed; just raise the original exception
+ raise e
def auth_publickey(self, username, key, event=None):
"""
@@ -1019,9 +1673,21 @@ class Transport(threading.Thread, ClosingContextManager):
event was passed in)
:raises: `.SSHException` -- if there was a network error
"""
- pass
+ if (not self.active) or (not self.initial_kex_done):
+ # we should never try to authenticate unless we're on a secure link
+ raise SSHException("No existing session")
+ if event is None:
+ my_event = threading.Event()
+ else:
+ my_event = event
+ self.auth_handler = AuthHandler(self)
+ self.auth_handler.auth_publickey(username, key, my_event)
+ if event is not None:
+ # caller wants to wait for event themselves
+ return []
+ return self.auth_handler.wait_for_response(my_event)
- def auth_interactive(self, username, handler, submethods=''):
+ def auth_interactive(self, username, handler, submethods=""):
"""
Authenticate to the server interactively. A handler is used to answer
arbitrary questions from the server. On many servers, this is just a
@@ -1064,16 +1730,38 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionadded:: 1.5
"""
- pass
+ if (not self.active) or (not self.initial_kex_done):
+ # we should never try to authenticate unless we're on a secure link
+ raise SSHException("No existing session")
+ my_event = threading.Event()
+ self.auth_handler = AuthHandler(self)
+ self.auth_handler.auth_interactive(
+ username, handler, my_event, submethods
+ )
+ return self.auth_handler.wait_for_response(my_event)
- def auth_interactive_dumb(self, username, handler=None, submethods=''):
+ def auth_interactive_dumb(self, username, handler=None, submethods=""):
"""
Authenticate to the server interactively but dumber.
Just print the prompt and / or instructions to stdout and send back
the response. This is good for situations where partial auth is
achieved by key and then the user has to enter a 2fac token.
"""
- pass
+
+ if not handler:
+
+ def handler(title, instructions, prompt_list):
+ answers = []
+ if title:
+ print(title.strip())
+ if instructions:
+ print(instructions.strip())
+ for prompt, show_input in prompt_list:
+ print(prompt.strip(), end=" ")
+ answers.append(input())
+ return answers
+
+ return self.auth_interactive(username, handler, submethods)
def auth_gssapi_with_mic(self, username, gss_host, gss_deleg_creds):
"""
@@ -1091,7 +1779,15 @@ class Transport(threading.Thread, ClosingContextManager):
event was passed in)
:raises: `.SSHException` -- if there was a network error
"""
- pass
+ if (not self.active) or (not self.initial_kex_done):
+ # we should never try to authenticate unless we're on a secure link
+ raise SSHException("No existing session")
+ my_event = threading.Event()
+ self.auth_handler = AuthHandler(self)
+ self.auth_handler.auth_gssapi_with_mic(
+ username, gss_host, gss_deleg_creds, my_event
+ )
+ return self.auth_handler.wait_for_response(my_event)
def auth_gssapi_keyex(self, username):
"""
@@ -1108,7 +1804,13 @@ class Transport(threading.Thread, ClosingContextManager):
if the authentication failed (and no event was passed in)
:raises: `.SSHException` -- if there was a network error
"""
- pass
+ if (not self.active) or (not self.initial_kex_done):
+ # we should never try to authenticate unless we're on a secure link
+ raise SSHException("No existing session")
+ my_event = threading.Event()
+ self.auth_handler = AuthHandler(self)
+ self.auth_handler.auth_gssapi_keyex(username, my_event)
+ return self.auth_handler.wait_for_response(my_event)
def set_log_channel(self, name):
"""
@@ -1121,7 +1823,9 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionadded:: 1.1
"""
- pass
+ self.log_name = name
+ self.logger = util.get_logger(name)
+ self.packetizer.set_log(self.logger)
def get_log_channel(self):
"""
@@ -1131,7 +1835,7 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionadded:: 1.2
"""
- pass
+ return self.log_name
def set_hexdump(self, hexdump):
"""
@@ -1143,7 +1847,7 @@ class Transport(threading.Thread, ClosingContextManager):
``True`` to log protocol traffix (in hex) to the log; ``False``
otherwise.
"""
- pass
+ self.packetizer.set_hexdump(hexdump)
def get_hexdump(self):
"""
@@ -1154,7 +1858,7 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionadded:: 1.4
"""
- pass
+ return self.packetizer.get_hexdump()
def use_compression(self, compress=True):
"""
@@ -1168,7 +1872,10 @@ class Transport(threading.Thread, ClosingContextManager):
.. versionadded:: 1.5.2
"""
- pass
+ if compress:
+ self._preferred_compression = ("zlib@openssh.com", "zlib", "none")
+ else:
+ self._preferred_compression = ("none",)
def getpeername(self):
"""
@@ -1182,42 +1889,194 @@ class Transport(threading.Thread, ClosingContextManager):
the address of the remote host, if known, as a ``(str, int)``
tuple.
"""
- pass
+ gp = getattr(self.sock, "getpeername", None)
+ if gp is None:
+ return "unknown", 0
+ return gp()
+
+ def stop_thread(self):
+ self.active = False
+ self.packetizer.close()
+ # Keep trying to join() our main thread, quickly, until:
+ # * We join()ed successfully (self.is_alive() == False)
+ # * Or it looks like we've hit issue #520 (socket.recv hitting some
+ # race condition preventing it from timing out correctly), wherein
+ # our socket and packetizer are both closed (but where we'd
+ # otherwise be sitting forever on that recv()).
+ while (
+ self.is_alive()
+ and self is not threading.current_thread()
+ and not self.sock._closed
+ and not self.packetizer.closed
+ ):
+ self.join(0.1)
+
+ # internals...
+
+ # TODO 4.0: make a public alias for this because multiple other classes
+ # already explicitly rely on it...or just rewrite logging :D
+ def _log(self, level, msg, *args):
+ if issubclass(type(msg), list):
+ for m in msg:
+ self.logger.log(level, m)
+ else:
+ self.logger.log(level, msg, *args)
def _get_modulus_pack(self):
"""used by KexGex to find primes for group exchange"""
- pass
+ return self._modulus_pack
def _next_channel(self):
"""you are holding the lock"""
- pass
+ chanid = self._channel_counter
+ while self._channels.get(chanid) is not None:
+ self._channel_counter = (self._channel_counter + 1) & 0xFFFFFF
+ chanid = self._channel_counter
+ self._channel_counter = (self._channel_counter + 1) & 0xFFFFFF
+ return chanid
def _unlink_channel(self, chanid):
"""used by a Channel to remove itself from the active channel list"""
- pass
+ self._channels.delete(chanid)
+
+ def _send_message(self, data):
+ self.packetizer.send_message(data)
def _send_user_message(self, data):
"""
send a message, but block if we're in key negotiation. this is used
for user-initiated requests.
"""
- pass
+ start = time.time()
+ while True:
+ self.clear_to_send.wait(0.1)
+ if not self.active:
+ self._log(
+ DEBUG, "Dropping user packet because connection is dead."
+ ) # noqa
+ return
+ self.clear_to_send_lock.acquire()
+ if self.clear_to_send.is_set():
+ break
+ self.clear_to_send_lock.release()
+ if time.time() > start + self.clear_to_send_timeout:
+ raise SSHException(
+ "Key-exchange timed out waiting for key negotiation"
+ ) # noqa
+ try:
+ self._send_message(data)
+ finally:
+ self.clear_to_send_lock.release()
def _set_K_H(self, k, h):
"""
Used by a kex obj to set the K (root key) and H (exchange hash).
"""
- pass
+ self.K = k
+ self.H = h
+ if self.session_id is None:
+ self.session_id = h
def _expect_packet(self, *ptypes):
"""
Used by a kex obj to register the next packet type it expects to see.
"""
- pass
+ self._expected_packet = tuple(ptypes)
+
+ def _verify_key(self, host_key, sig):
+ key = self._key_info[self.host_key_type](Message(host_key))
+ if key is None:
+ raise SSHException("Unknown host key type")
+ if not key.verify_ssh_sig(self.H, Message(sig)):
+ raise SSHException(
+ "Signature verification ({}) failed.".format(
+ self.host_key_type
+ )
+ ) # noqa
+ self.host_key = key
def _compute_key(self, id, nbytes):
"""id is 'A' - 'F' for the various keys used by ssh"""
- pass
+ m = Message()
+ m.add_mpint(self.K)
+ m.add_bytes(self.H)
+ m.add_byte(b(id))
+ m.add_bytes(self.session_id)
+ # Fallback to SHA1 for kex engines that fail to specify a hex
+ # algorithm, or for e.g. transport tests that don't run kexinit.
+ hash_algo = getattr(self.kex_engine, "hash_algo", None)
+ hash_select_msg = "kex engine {} specified hash_algo {!r}".format(
+ self.kex_engine.__class__.__name__, hash_algo
+ )
+ if hash_algo is None:
+ hash_algo = sha1
+ hash_select_msg += ", falling back to sha1"
+ if not hasattr(self, "_logged_hash_selection"):
+ self._log(DEBUG, hash_select_msg)
+ setattr(self, "_logged_hash_selection", True)
+ out = sofar = hash_algo(m.asbytes()).digest()
+ while len(out) < nbytes:
+ m = Message()
+ m.add_mpint(self.K)
+ m.add_bytes(self.H)
+ m.add_bytes(sofar)
+ digest = hash_algo(m.asbytes()).digest()
+ out += digest
+ sofar += digest
+ return out[:nbytes]
+
+ def _get_cipher(self, name, key, iv, operation):
+ if name not in self._cipher_info:
+ raise SSHException("Unknown client cipher " + name)
+ else:
+ cipher = Cipher(
+ self._cipher_info[name]["class"](key),
+ self._cipher_info[name]["mode"](iv),
+ backend=default_backend(),
+ )
+ if operation is self._ENCRYPT:
+ return cipher.encryptor()
+ else:
+ return cipher.decryptor()
+
+ def _set_forward_agent_handler(self, handler):
+ if handler is None:
+
+ def default_handler(channel):
+ self._queue_incoming_channel(channel)
+
+ self._forward_agent_handler = default_handler
+ else:
+ self._forward_agent_handler = handler
+
+ def _set_x11_handler(self, handler):
+ # only called if a channel has turned on x11 forwarding
+ if handler is None:
+ # by default, use the same mechanism as accept()
+ def default_handler(channel, src_addr_port):
+ self._queue_incoming_channel(channel)
+
+ self._x11_handler = default_handler
+ else:
+ self._x11_handler = handler
+
+ def _queue_incoming_channel(self, channel):
+ self.lock.acquire()
+ try:
+ self.server_accepts.append(channel)
+ self.server_accept_cv.notify()
+ finally:
+ self.lock.release()
+
+ def _sanitize_window_size(self, window_size):
+ if window_size is None:
+ window_size = self.default_window_size
+ return clamp_value(MIN_WINDOW_SIZE, window_size, MAX_WINDOW_SIZE)
+
+ def _sanitize_packet_size(self, max_packet_size):
+ if max_packet_size is None:
+ max_packet_size = self.default_max_packet_size
+ return clamp_value(MIN_PACKET_SIZE, max_packet_size, MAX_WINDOW_SIZE)
def _ensure_authed(self, ptype, message):
"""
@@ -1229,7 +2088,33 @@ class Transport(threading.Thread, ClosingContextManager):
Otherwise (client mode, authed, or pre-auth message) returns None.
"""
- pass
+ if (
+ not self.server_mode
+ or ptype <= HIGHEST_USERAUTH_MESSAGE_ID
+ or self.is_authenticated()
+ ):
+ return None
+ # WELP. We must be dealing with someone trying to do non-auth things
+ # without being authed. Tell them off, based on message class.
+ reply = Message()
+ # Global requests have no details, just failure.
+ if ptype == MSG_GLOBAL_REQUEST:
+ reply.add_byte(cMSG_REQUEST_FAILURE)
+ # Channel opens let us reject w/ a specific type + message.
+ elif ptype == MSG_CHANNEL_OPEN:
+ kind = message.get_text() # noqa
+ chanid = message.get_int()
+ reply.add_byte(cMSG_CHANNEL_OPEN_FAILURE)
+ reply.add_int(chanid)
+ reply.add_int(OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)
+ reply.add_string("")
+ reply.add_string("en")
+ # NOTE: Post-open channel messages do not need checking; the above will
+ # reject attempts to open channels, meaning that even if a malicious
+ # user tries to send a MSG_CHANNEL_REQUEST, it will simply fall under
+ # the logic that handles unknown channel IDs (as the channel list will
+ # be empty.)
+ return reply
def _enforce_strict_kex(self, ptype):
"""
@@ -1239,32 +2124,1001 @@ class Transport(threading.Thread, ClosingContextManager):
messages; it does not interrogate ``ptype`` besides using it to log
more accurately.
"""
- pass
+ if self.agreed_on_strict_kex and not self.initial_kex_done:
+ name = MSG_NAMES.get(ptype, f"msg {ptype}")
+ raise MessageOrderError(
+ f"In strict-kex mode, but was sent {name!r}!"
+ )
+
+ def run(self):
+ # (use the exposed "run" method, because if we specify a thread target
+ # of a private method, threading.Thread will keep a reference to it
+ # indefinitely, creating a GC cycle and not letting Transport ever be
+ # GC'd. it's a bug in Thread.)
+
+ # Hold reference to 'sys' so we can test sys.modules to detect
+ # interpreter shutdown.
+ self.sys = sys
+
+ # active=True occurs before the thread is launched, to avoid a race
+ _active_threads.append(self)
+ tid = hex(id(self) & xffffffff)
+ if self.server_mode:
+ self._log(DEBUG, "starting thread (server mode): {}".format(tid))
+ else:
+ self._log(DEBUG, "starting thread (client mode): {}".format(tid))
+ try:
+ try:
+ self.packetizer.write_all(b(self.local_version + "\r\n"))
+ self._log(
+ DEBUG,
+ "Local version/idstring: {}".format(self.local_version),
+ ) # noqa
+ self._check_banner()
+ # The above is actually very much part of the handshake, but
+ # sometimes the banner can be read but the machine is not
+ # responding, for example when the remote ssh daemon is loaded
+ # in to memory but we can not read from the disk/spawn a new
+ # shell.
+ # Make sure we can specify a timeout for the initial handshake.
+ # Re-use the banner timeout for now.
+ self.packetizer.start_handshake(self.handshake_timeout)
+ self._send_kex_init()
+ self._expect_packet(MSG_KEXINIT)
+
+ while self.active:
+ if self.packetizer.need_rekey() and not self.in_kex:
+ self._send_kex_init()
+ try:
+ ptype, m = self.packetizer.read_message()
+ except NeedRekeyException:
+ continue
+ if ptype == MSG_IGNORE:
+ self._enforce_strict_kex(ptype)
+ continue
+ elif ptype == MSG_DISCONNECT:
+ self._parse_disconnect(m)
+ break
+ elif ptype == MSG_DEBUG:
+ self._enforce_strict_kex(ptype)
+ self._parse_debug(m)
+ continue
+ if len(self._expected_packet) > 0:
+ if ptype not in self._expected_packet:
+ exc_class = SSHException
+ if self.agreed_on_strict_kex:
+ exc_class = MessageOrderError
+ raise exc_class(
+ "Expecting packet from {!r}, got {:d}".format(
+ self._expected_packet, ptype
+ )
+ ) # noqa
+ self._expected_packet = tuple()
+ # These message IDs indicate key exchange & will differ
+ # depending on exact exchange algorithm
+ if (ptype >= 30) and (ptype <= 41):
+ self.kex_engine.parse_next(ptype, m)
+ continue
+
+ if ptype in self._handler_table:
+ error_msg = self._ensure_authed(ptype, m)
+ if error_msg:
+ self._send_message(error_msg)
+ else:
+ self._handler_table[ptype](m)
+ elif ptype in self._channel_handler_table:
+ chanid = m.get_int()
+ chan = self._channels.get(chanid)
+ if chan is not None:
+ self._channel_handler_table[ptype](chan, m)
+ elif chanid in self.channels_seen:
+ self._log(
+ DEBUG,
+ "Ignoring message for dead channel {:d}".format( # noqa
+ chanid
+ ),
+ )
+ else:
+ self._log(
+ ERROR,
+ "Channel request for unknown channel {:d}".format( # noqa
+ chanid
+ ),
+ )
+ break
+ elif (
+ self.auth_handler is not None
+ and ptype in self.auth_handler._handler_table
+ ):
+ handler = self.auth_handler._handler_table[ptype]
+ handler(m)
+ if len(self._expected_packet) > 0:
+ continue
+ else:
+ # Respond with "I don't implement this particular
+ # message type" message (unless the message type was
+ # itself literally MSG_UNIMPLEMENTED, in which case, we
+ # just shut up to avoid causing a useless loop).
+ name = MSG_NAMES[ptype]
+ warning = "Oops, unhandled type {} ({!r})".format(
+ ptype, name
+ )
+ self._log(WARNING, warning)
+ if ptype != MSG_UNIMPLEMENTED:
+ msg = Message()
+ msg.add_byte(cMSG_UNIMPLEMENTED)
+ msg.add_int(m.seqno)
+ self._send_message(msg)
+ self.packetizer.complete_handshake()
+ except SSHException as e:
+ self._log(
+ ERROR,
+ "Exception ({}): {}".format(
+ "server" if self.server_mode else "client", e
+ ),
+ )
+ self._log(ERROR, util.tb_strings())
+ self.saved_exception = e
+ except EOFError as e:
+ self._log(DEBUG, "EOF in transport thread")
+ self.saved_exception = e
+ except socket.error as e:
+ if type(e.args) is tuple:
+ if e.args:
+ emsg = "{} ({:d})".format(e.args[1], e.args[0])
+ else: # empty tuple, e.g. socket.timeout
+ emsg = str(e) or repr(e)
+ else:
+ emsg = e.args
+ self._log(ERROR, "Socket exception: " + emsg)
+ self.saved_exception = e
+ except Exception as e:
+ self._log(ERROR, "Unknown exception: " + str(e))
+ self._log(ERROR, util.tb_strings())
+ self.saved_exception = e
+ _active_threads.remove(self)
+ for chan in list(self._channels.values()):
+ chan._unlink()
+ if self.active:
+ self.active = False
+ self.packetizer.close()
+ if self.completion_event is not None:
+ self.completion_event.set()
+ if self.auth_handler is not None:
+ self.auth_handler.abort()
+ for event in self.channel_events.values():
+ event.set()
+ try:
+ self.lock.acquire()
+ self.server_accept_cv.notify()
+ finally:
+ self.lock.release()
+ self.sock.close()
+ except:
+ # Don't raise spurious 'NoneType has no attribute X' errors when we
+ # wake up during interpreter shutdown. Or rather -- raise
+ # everything *if* sys.modules (used as a convenient sentinel)
+ # appears to still exist.
+ if self.sys.modules is not None:
+ raise
+
+ def _log_agreement(self, which, local, remote):
+ # Log useful, non-duplicative line re: an agreed-upon algorithm.
+ # Old code implied algorithms could be asymmetrical (different for
+ # inbound vs outbound) so we preserve that possibility.
+ msg = "{}: ".format(which)
+ if local == remote:
+ msg += local
+ else:
+ msg += "local={}, remote={}".format(local, remote)
+ self._log(DEBUG, msg)
+
+ # protocol stages
+
+ def _negotiate_keys(self, m):
+ # throws SSHException on anything unusual
+ self.clear_to_send_lock.acquire()
+ try:
+ self.clear_to_send.clear()
+ finally:
+ self.clear_to_send_lock.release()
+ if self.local_kex_init is None:
+ # remote side wants to renegotiate
+ self._send_kex_init()
+ self._parse_kex_init(m)
+ self.kex_engine.start_kex()
+
+ def _check_banner(self):
+ # this is slow, but we only have to do it once
+ for i in range(100):
+ # give them 15 seconds for the first line, then just 2 seconds
+ # each additional line. (some sites have very high latency.)
+ if i == 0:
+ timeout = self.banner_timeout
+ else:
+ timeout = 2
+ try:
+ buf = self.packetizer.readline(timeout)
+ except ProxyCommandFailure:
+ raise
+ except Exception as e:
+ raise SSHException(
+ "Error reading SSH protocol banner" + str(e)
+ )
+ if buf[:4] == "SSH-":
+ break
+ self._log(DEBUG, "Banner: " + buf)
+ if buf[:4] != "SSH-":
+ raise SSHException('Indecipherable protocol version "' + buf + '"')
+ # save this server version string for later
+ self.remote_version = buf
+ self._log(DEBUG, "Remote version/idstring: {}".format(buf))
+ # pull off any attached comment
+ # NOTE: comment used to be stored in a variable and then...never used.
+ # since 2003. ca 877cd974b8182d26fa76d566072917ea67b64e67
+ i = buf.find(" ")
+ if i >= 0:
+ buf = buf[:i]
+ # parse out version string and make sure it matches
+ segs = buf.split("-", 2)
+ if len(segs) < 3:
+ raise SSHException("Invalid SSH banner")
+ version = segs[1]
+ client = segs[2]
+ if version != "1.99" and version != "2.0":
+ msg = "Incompatible version ({} instead of 2.0)"
+ raise IncompatiblePeer(msg.format(version))
+ msg = "Connected (version {}, client {})".format(version, client)
+ self._log(INFO, msg)
def _send_kex_init(self):
"""
announce to the other side that we'd like to negotiate keys, and what
kind of key negotiation we support.
"""
- pass
+ self.clear_to_send_lock.acquire()
+ try:
+ self.clear_to_send.clear()
+ finally:
+ self.clear_to_send_lock.release()
+ self.gss_kex_used = False
+ self.in_kex = True
+ kex_algos = list(self.preferred_kex)
+ if self.server_mode:
+ mp_required_prefix = "diffie-hellman-group-exchange-sha"
+ kex_mp = [k for k in kex_algos if k.startswith(mp_required_prefix)]
+ if (self._modulus_pack is None) and (len(kex_mp) > 0):
+ # can't do group-exchange if we don't have a pack of potential
+ # primes
+ pkex = [
+ k
+ for k in self.get_security_options().kex
+ if not k.startswith(mp_required_prefix)
+ ]
+ self.get_security_options().kex = pkex
+ available_server_keys = list(
+ filter(
+ list(self.server_key_dict.keys()).__contains__,
+ # TODO: ensure tests will catch if somebody streamlines
+ # this by mistake - case is the admittedly silly one where
+ # the only calls to add_server_key() contain keys which
+ # were filtered out of the below via disabled_algorithms.
+ # If this is streamlined, we would then be allowing the
+ # disabled algorithm(s) for hostkey use
+ # TODO: honestly this prob just wants to get thrown out
+ # when we make kex configuration more straightforward
+ self.preferred_keys,
+ )
+ )
+ else:
+ available_server_keys = self.preferred_keys
+ # Signal support for MSG_EXT_INFO so server will send it to us.
+ # NOTE: doing this here handily means we don't even consider this
+ # value when agreeing on real kex algo to use (which is a common
+ # pitfall when adding this apparently).
+ kex_algos.append("ext-info-c")
+
+ # Similar to ext-info, but used in both server modes, so done outside
+ # of above if/else.
+ if self.advertise_strict_kex:
+ which = "s" if self.server_mode else "c"
+ kex_algos.append(f"kex-strict-{which}-v00@openssh.com")
+
+ m = Message()
+ m.add_byte(cMSG_KEXINIT)
+ m.add_bytes(os.urandom(16))
+ m.add_list(kex_algos)
+ m.add_list(available_server_keys)
+ m.add_list(self.preferred_ciphers)
+ m.add_list(self.preferred_ciphers)
+ m.add_list(self.preferred_macs)
+ m.add_list(self.preferred_macs)
+ m.add_list(self.preferred_compression)
+ m.add_list(self.preferred_compression)
+ m.add_string(bytes())
+ m.add_string(bytes())
+ m.add_boolean(False)
+ m.add_int(0)
+ # save a copy for later (needed to compute a hash)
+ self.local_kex_init = self._latest_kex_init = m.asbytes()
+ self._send_message(m)
+
+ def _really_parse_kex_init(self, m, ignore_first_byte=False):
+ parsed = {}
+ if ignore_first_byte:
+ m.get_byte()
+ m.get_bytes(16) # cookie, discarded
+ parsed["kex_algo_list"] = m.get_list()
+ parsed["server_key_algo_list"] = m.get_list()
+ parsed["client_encrypt_algo_list"] = m.get_list()
+ parsed["server_encrypt_algo_list"] = m.get_list()
+ parsed["client_mac_algo_list"] = m.get_list()
+ parsed["server_mac_algo_list"] = m.get_list()
+ parsed["client_compress_algo_list"] = m.get_list()
+ parsed["server_compress_algo_list"] = m.get_list()
+ parsed["client_lang_list"] = m.get_list()
+ parsed["server_lang_list"] = m.get_list()
+ parsed["kex_follows"] = m.get_boolean()
+ m.get_int() # unused
+ return parsed
+
+ def _get_latest_kex_init(self):
+ return self._really_parse_kex_init(
+ Message(self._latest_kex_init),
+ ignore_first_byte=True,
+ )
+
+ def _parse_kex_init(self, m):
+ parsed = self._really_parse_kex_init(m)
+ kex_algo_list = parsed["kex_algo_list"]
+ server_key_algo_list = parsed["server_key_algo_list"]
+ client_encrypt_algo_list = parsed["client_encrypt_algo_list"]
+ server_encrypt_algo_list = parsed["server_encrypt_algo_list"]
+ client_mac_algo_list = parsed["client_mac_algo_list"]
+ server_mac_algo_list = parsed["server_mac_algo_list"]
+ client_compress_algo_list = parsed["client_compress_algo_list"]
+ server_compress_algo_list = parsed["server_compress_algo_list"]
+ client_lang_list = parsed["client_lang_list"]
+ server_lang_list = parsed["server_lang_list"]
+ kex_follows = parsed["kex_follows"]
+
+ self._log(DEBUG, "=== Key exchange possibilities ===")
+ for prefix, value in (
+ ("kex algos", kex_algo_list),
+ ("server key", server_key_algo_list),
+ # TODO: shouldn't these two lines say "cipher" to match usual
+ # terminology (including elsewhere in paramiko!)?
+ ("client encrypt", client_encrypt_algo_list),
+ ("server encrypt", server_encrypt_algo_list),
+ ("client mac", client_mac_algo_list),
+ ("server mac", server_mac_algo_list),
+ ("client compress", client_compress_algo_list),
+ ("server compress", server_compress_algo_list),
+ ("client lang", client_lang_list),
+ ("server lang", server_lang_list),
+ ):
+ if value == [""]:
+ value = ["<none>"]
+ value = ", ".join(value)
+ self._log(DEBUG, "{}: {}".format(prefix, value))
+ self._log(DEBUG, "kex follows: {}".format(kex_follows))
+ self._log(DEBUG, "=== Key exchange agreements ===")
+
+ # Record, and strip out, ext-info and/or strict-kex non-algorithms
+ self._remote_ext_info = None
+ self._remote_strict_kex = None
+ to_pop = []
+ for i, algo in enumerate(kex_algo_list):
+ if algo.startswith("ext-info-"):
+ self._remote_ext_info = algo
+ to_pop.insert(0, i)
+ elif algo.startswith("kex-strict-"):
+ # NOTE: this is what we are expecting from the /remote/ end.
+ which = "c" if self.server_mode else "s"
+ expected = f"kex-strict-{which}-v00@openssh.com"
+ # Set strict mode if agreed.
+ self.agreed_on_strict_kex = (
+ algo == expected and self.advertise_strict_kex
+ )
+ self._log(
+ DEBUG, f"Strict kex mode: {self.agreed_on_strict_kex}"
+ )
+ to_pop.insert(0, i)
+ for i in to_pop:
+ kex_algo_list.pop(i)
+
+ # CVE mitigation: expect zeroed-out seqno anytime we are performing kex
+ # init phase, if strict mode was negotiated.
+ if (
+ self.agreed_on_strict_kex
+ and not self.initial_kex_done
+ and m.seqno != 0
+ ):
+ raise MessageOrderError(
+ "In strict-kex mode, but KEXINIT was not the first packet!"
+ )
+
+ # as a server, we pick the first item in the client's list that we
+ # support.
+ # as a client, we pick the first item in our list that the server
+ # supports.
+ if self.server_mode:
+ agreed_kex = list(
+ filter(self.preferred_kex.__contains__, kex_algo_list)
+ )
+ else:
+ agreed_kex = list(
+ filter(kex_algo_list.__contains__, self.preferred_kex)
+ )
+ if len(agreed_kex) == 0:
+ # TODO: do an auth-overhaul style aggregate exception here?
+ # TODO: would let us streamline log output & show all failures up
+ # front
+ raise IncompatiblePeer(
+ "Incompatible ssh peer (no acceptable kex algorithm)"
+ ) # noqa
+ self.kex_engine = self._kex_info[agreed_kex[0]](self)
+ self._log(DEBUG, "Kex: {}".format(agreed_kex[0]))
+
+ if self.server_mode:
+ available_server_keys = list(
+ filter(
+ list(self.server_key_dict.keys()).__contains__,
+ self.preferred_keys,
+ )
+ )
+ agreed_keys = list(
+ filter(
+ available_server_keys.__contains__, server_key_algo_list
+ )
+ )
+ else:
+ agreed_keys = list(
+ filter(server_key_algo_list.__contains__, self.preferred_keys)
+ )
+ if len(agreed_keys) == 0:
+ raise IncompatiblePeer(
+ "Incompatible ssh peer (no acceptable host key)"
+ ) # noqa
+ self.host_key_type = agreed_keys[0]
+ if self.server_mode and (self.get_server_key() is None):
+ raise IncompatiblePeer(
+ "Incompatible ssh peer (can't match requested host key type)"
+ ) # noqa
+ self._log_agreement("HostKey", agreed_keys[0], agreed_keys[0])
+
+ if self.server_mode:
+ agreed_local_ciphers = list(
+ filter(
+ self.preferred_ciphers.__contains__,
+ server_encrypt_algo_list,
+ )
+ )
+ agreed_remote_ciphers = list(
+ filter(
+ self.preferred_ciphers.__contains__,
+ client_encrypt_algo_list,
+ )
+ )
+ else:
+ agreed_local_ciphers = list(
+ filter(
+ client_encrypt_algo_list.__contains__,
+ self.preferred_ciphers,
+ )
+ )
+ agreed_remote_ciphers = list(
+ filter(
+ server_encrypt_algo_list.__contains__,
+ self.preferred_ciphers,
+ )
+ )
+ if len(agreed_local_ciphers) == 0 or len(agreed_remote_ciphers) == 0:
+ raise IncompatiblePeer(
+ "Incompatible ssh server (no acceptable ciphers)"
+ ) # noqa
+ self.local_cipher = agreed_local_ciphers[0]
+ self.remote_cipher = agreed_remote_ciphers[0]
+ self._log_agreement(
+ "Cipher", local=self.local_cipher, remote=self.remote_cipher
+ )
+
+ if self.server_mode:
+ agreed_remote_macs = list(
+ filter(self.preferred_macs.__contains__, client_mac_algo_list)
+ )
+ agreed_local_macs = list(
+ filter(self.preferred_macs.__contains__, server_mac_algo_list)
+ )
+ else:
+ agreed_local_macs = list(
+ filter(client_mac_algo_list.__contains__, self.preferred_macs)
+ )
+ agreed_remote_macs = list(
+ filter(server_mac_algo_list.__contains__, self.preferred_macs)
+ )
+ if (len(agreed_local_macs) == 0) or (len(agreed_remote_macs) == 0):
+ raise IncompatiblePeer(
+ "Incompatible ssh server (no acceptable macs)"
+ )
+ self.local_mac = agreed_local_macs[0]
+ self.remote_mac = agreed_remote_macs[0]
+ self._log_agreement(
+ "MAC", local=self.local_mac, remote=self.remote_mac
+ )
+
+ if self.server_mode:
+ agreed_remote_compression = list(
+ filter(
+ self.preferred_compression.__contains__,
+ client_compress_algo_list,
+ )
+ )
+ agreed_local_compression = list(
+ filter(
+ self.preferred_compression.__contains__,
+ server_compress_algo_list,
+ )
+ )
+ else:
+ agreed_local_compression = list(
+ filter(
+ client_compress_algo_list.__contains__,
+ self.preferred_compression,
+ )
+ )
+ agreed_remote_compression = list(
+ filter(
+ server_compress_algo_list.__contains__,
+ self.preferred_compression,
+ )
+ )
+ if (
+ len(agreed_local_compression) == 0
+ or len(agreed_remote_compression) == 0
+ ):
+ msg = "Incompatible ssh server (no acceptable compression)"
+ msg += " {!r} {!r} {!r}"
+ raise IncompatiblePeer(
+ msg.format(
+ agreed_local_compression,
+ agreed_remote_compression,
+ self.preferred_compression,
+ )
+ )
+ self.local_compression = agreed_local_compression[0]
+ self.remote_compression = agreed_remote_compression[0]
+ self._log_agreement(
+ "Compression",
+ local=self.local_compression,
+ remote=self.remote_compression,
+ )
+ self._log(DEBUG, "=== End of kex handshake ===")
+
+ # save for computing hash later...
+ # now wait! openssh has a bug (and others might too) where there are
+ # actually some extra bytes (one NUL byte in openssh's case) added to
+ # the end of the packet but not parsed. turns out we need to throw
+ # away those bytes because they aren't part of the hash.
+ self.remote_kex_init = cMSG_KEXINIT + m.get_so_far()
def _activate_inbound(self):
"""switch on newly negotiated encryption parameters for
inbound traffic"""
- pass
+ block_size = self._cipher_info[self.remote_cipher]["block-size"]
+ if self.server_mode:
+ IV_in = self._compute_key("A", block_size)
+ key_in = self._compute_key(
+ "C", self._cipher_info[self.remote_cipher]["key-size"]
+ )
+ else:
+ IV_in = self._compute_key("B", block_size)
+ key_in = self._compute_key(
+ "D", self._cipher_info[self.remote_cipher]["key-size"]
+ )
+ engine = self._get_cipher(
+ self.remote_cipher, key_in, IV_in, self._DECRYPT
+ )
+ etm = "etm@openssh.com" in self.remote_mac
+ mac_size = self._mac_info[self.remote_mac]["size"]
+ mac_engine = self._mac_info[self.remote_mac]["class"]
+ # initial mac keys are done in the hash's natural size (not the
+ # potentially truncated transmission size)
+ if self.server_mode:
+ mac_key = self._compute_key("E", mac_engine().digest_size)
+ else:
+ mac_key = self._compute_key("F", mac_engine().digest_size)
+ self.packetizer.set_inbound_cipher(
+ engine, block_size, mac_engine, mac_size, mac_key, etm=etm
+ )
+ compress_in = self._compression_info[self.remote_compression][1]
+ if compress_in is not None and (
+ self.remote_compression != "zlib@openssh.com" or self.authenticated
+ ):
+ self._log(DEBUG, "Switching on inbound compression ...")
+ self.packetizer.set_inbound_compressor(compress_in())
+ # Reset inbound sequence number if strict mode.
+ if self.agreed_on_strict_kex:
+ self._log(
+ DEBUG,
+ "Resetting inbound seqno after NEWKEYS due to strict mode",
+ )
+ self.packetizer.reset_seqno_in()
def _activate_outbound(self):
"""switch on newly negotiated encryption parameters for
outbound traffic"""
- pass
- _channel_handler_table = {MSG_CHANNEL_SUCCESS: Channel._request_success,
- MSG_CHANNEL_FAILURE: Channel._request_failed, MSG_CHANNEL_DATA:
- Channel._feed, MSG_CHANNEL_EXTENDED_DATA: Channel._feed_extended,
+ m = Message()
+ m.add_byte(cMSG_NEWKEYS)
+ self._send_message(m)
+ # Reset outbound sequence number if strict mode.
+ if self.agreed_on_strict_kex:
+ self._log(
+ DEBUG,
+ "Resetting outbound seqno after NEWKEYS due to strict mode",
+ )
+ self.packetizer.reset_seqno_out()
+ block_size = self._cipher_info[self.local_cipher]["block-size"]
+ if self.server_mode:
+ IV_out = self._compute_key("B", block_size)
+ key_out = self._compute_key(
+ "D", self._cipher_info[self.local_cipher]["key-size"]
+ )
+ else:
+ IV_out = self._compute_key("A", block_size)
+ key_out = self._compute_key(
+ "C", self._cipher_info[self.local_cipher]["key-size"]
+ )
+ engine = self._get_cipher(
+ self.local_cipher, key_out, IV_out, self._ENCRYPT
+ )
+ etm = "etm@openssh.com" in self.local_mac
+ mac_size = self._mac_info[self.local_mac]["size"]
+ mac_engine = self._mac_info[self.local_mac]["class"]
+ # initial mac keys are done in the hash's natural size (not the
+ # potentially truncated transmission size)
+ if self.server_mode:
+ mac_key = self._compute_key("F", mac_engine().digest_size)
+ else:
+ mac_key = self._compute_key("E", mac_engine().digest_size)
+ sdctr = self.local_cipher.endswith("-ctr")
+ self.packetizer.set_outbound_cipher(
+ engine, block_size, mac_engine, mac_size, mac_key, sdctr, etm=etm
+ )
+ compress_out = self._compression_info[self.local_compression][0]
+ if compress_out is not None and (
+ self.local_compression != "zlib@openssh.com" or self.authenticated
+ ):
+ self._log(DEBUG, "Switching on outbound compression ...")
+ self.packetizer.set_outbound_compressor(compress_out())
+ if not self.packetizer.need_rekey():
+ self.in_kex = False
+ # If client indicated extension support, send that packet immediately
+ if (
+ self.server_mode
+ and self.server_sig_algs
+ and self._remote_ext_info == "ext-info-c"
+ ):
+ extensions = {"server-sig-algs": ",".join(self.preferred_pubkeys)}
+ m = Message()
+ m.add_byte(cMSG_EXT_INFO)
+ m.add_int(len(extensions))
+ for name, value in sorted(extensions.items()):
+ m.add_string(name)
+ m.add_string(value)
+ self._send_message(m)
+ # we always expect to receive NEWKEYS now
+ self._expect_packet(MSG_NEWKEYS)
+
+ def _auth_trigger(self):
+ self.authenticated = True
+ # delayed initiation of compression
+ if self.local_compression == "zlib@openssh.com":
+ compress_out = self._compression_info[self.local_compression][0]
+ self._log(DEBUG, "Switching on outbound compression ...")
+ self.packetizer.set_outbound_compressor(compress_out())
+ if self.remote_compression == "zlib@openssh.com":
+ compress_in = self._compression_info[self.remote_compression][1]
+ self._log(DEBUG, "Switching on inbound compression ...")
+ self.packetizer.set_inbound_compressor(compress_in())
+
+ def _parse_ext_info(self, msg):
+ # Packet is a count followed by that many key-string to possibly-bytes
+ # pairs.
+ extensions = {}
+ for _ in range(msg.get_int()):
+ name = msg.get_text()
+ value = msg.get_string()
+ extensions[name] = value
+ self._log(DEBUG, "Got EXT_INFO: {}".format(extensions))
+ # NOTE: this should work ok in cases where a server sends /two/ such
+ # messages; the RFC explicitly states a 2nd one should overwrite the
+ # 1st.
+ self.server_extensions = extensions
+
+ def _parse_newkeys(self, m):
+ self._log(DEBUG, "Switch to new keys ...")
+ self._activate_inbound()
+ # can also free a bunch of stuff here
+ self.local_kex_init = self.remote_kex_init = None
+ self.K = None
+ self.kex_engine = None
+ if self.server_mode and (self.auth_handler is None):
+ # create auth handler for server mode
+ self.auth_handler = AuthHandler(self)
+ if not self.initial_kex_done:
+ # this was the first key exchange
+ # (also signal to packetizer as it sometimes wants to know this
+ # status as well, eg when seqnos rollover)
+ self.initial_kex_done = self.packetizer._initial_kex_done = True
+ # send an event?
+ if self.completion_event is not None:
+ self.completion_event.set()
+ # it's now okay to send data again (if this was a re-key)
+ if not self.packetizer.need_rekey():
+ self.in_kex = False
+ self.clear_to_send_lock.acquire()
+ try:
+ self.clear_to_send.set()
+ finally:
+ self.clear_to_send_lock.release()
+ return
+
+ def _parse_disconnect(self, m):
+ code = m.get_int()
+ desc = m.get_text()
+ self._log(INFO, "Disconnect (code {:d}): {}".format(code, desc))
+
+ def _parse_global_request(self, m):
+ kind = m.get_text()
+ self._log(DEBUG, 'Received global request "{}"'.format(kind))
+ want_reply = m.get_boolean()
+ if not self.server_mode:
+ self._log(
+ DEBUG,
+ 'Rejecting "{}" global request from server.'.format(kind),
+ )
+ ok = False
+ elif kind == "tcpip-forward":
+ address = m.get_text()
+ port = m.get_int()
+ ok = self.server_object.check_port_forward_request(address, port)
+ if ok:
+ ok = (ok,)
+ elif kind == "cancel-tcpip-forward":
+ address = m.get_text()
+ port = m.get_int()
+ self.server_object.cancel_port_forward_request(address, port)
+ ok = True
+ else:
+ ok = self.server_object.check_global_request(kind, m)
+ extra = ()
+ if type(ok) is tuple:
+ extra = ok
+ ok = True
+ if want_reply:
+ msg = Message()
+ if ok:
+ msg.add_byte(cMSG_REQUEST_SUCCESS)
+ msg.add(*extra)
+ else:
+ msg.add_byte(cMSG_REQUEST_FAILURE)
+ self._send_message(msg)
+
+ def _parse_request_success(self, m):
+ self._log(DEBUG, "Global request successful.")
+ self.global_response = m
+ if self.completion_event is not None:
+ self.completion_event.set()
+
+ def _parse_request_failure(self, m):
+ self._log(DEBUG, "Global request denied.")
+ self.global_response = None
+ if self.completion_event is not None:
+ self.completion_event.set()
+
+ def _parse_channel_open_success(self, m):
+ chanid = m.get_int()
+ server_chanid = m.get_int()
+ server_window_size = m.get_int()
+ server_max_packet_size = m.get_int()
+ chan = self._channels.get(chanid)
+ if chan is None:
+ self._log(WARNING, "Success for unrequested channel! [??]")
+ return
+ self.lock.acquire()
+ try:
+ chan._set_remote_channel(
+ server_chanid, server_window_size, server_max_packet_size
+ )
+ self._log(DEBUG, "Secsh channel {:d} opened.".format(chanid))
+ if chanid in self.channel_events:
+ self.channel_events[chanid].set()
+ del self.channel_events[chanid]
+ finally:
+ self.lock.release()
+ return
+
+ def _parse_channel_open_failure(self, m):
+ chanid = m.get_int()
+ reason = m.get_int()
+ reason_str = m.get_text()
+ m.get_text() # ignored language
+ reason_text = CONNECTION_FAILED_CODE.get(reason, "(unknown code)")
+ self._log(
+ ERROR,
+ "Secsh channel {:d} open FAILED: {}: {}".format(
+ chanid, reason_str, reason_text
+ ),
+ )
+ self.lock.acquire()
+ try:
+ self.saved_exception = ChannelException(reason, reason_text)
+ if chanid in self.channel_events:
+ self._channels.delete(chanid)
+ if chanid in self.channel_events:
+ self.channel_events[chanid].set()
+ del self.channel_events[chanid]
+ finally:
+ self.lock.release()
+ return
+
+ def _parse_channel_open(self, m):
+ kind = m.get_text()
+ chanid = m.get_int()
+ initial_window_size = m.get_int()
+ max_packet_size = m.get_int()
+ reject = False
+ if (
+ kind == "auth-agent@openssh.com"
+ and self._forward_agent_handler is not None
+ ):
+ self._log(DEBUG, "Incoming forward agent connection")
+ self.lock.acquire()
+ try:
+ my_chanid = self._next_channel()
+ finally:
+ self.lock.release()
+ elif (kind == "x11") and (self._x11_handler is not None):
+ origin_addr = m.get_text()
+ origin_port = m.get_int()
+ self._log(
+ DEBUG,
+ "Incoming x11 connection from {}:{:d}".format(
+ origin_addr, origin_port
+ ),
+ )
+ self.lock.acquire()
+ try:
+ my_chanid = self._next_channel()
+ finally:
+ self.lock.release()
+ elif (kind == "forwarded-tcpip") and (self._tcp_handler is not None):
+ server_addr = m.get_text()
+ server_port = m.get_int()
+ origin_addr = m.get_text()
+ origin_port = m.get_int()
+ self._log(
+ DEBUG,
+ "Incoming tcp forwarded connection from {}:{:d}".format(
+ origin_addr, origin_port
+ ),
+ )
+ self.lock.acquire()
+ try:
+ my_chanid = self._next_channel()
+ finally:
+ self.lock.release()
+ elif not self.server_mode:
+ self._log(
+ DEBUG,
+ 'Rejecting "{}" channel request from server.'.format(kind),
+ )
+ reject = True
+ reason = OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
+ else:
+ self.lock.acquire()
+ try:
+ my_chanid = self._next_channel()
+ finally:
+ self.lock.release()
+ if kind == "direct-tcpip":
+ # handle direct-tcpip requests coming from the client
+ dest_addr = m.get_text()
+ dest_port = m.get_int()
+ origin_addr = m.get_text()
+ origin_port = m.get_int()
+ reason = self.server_object.check_channel_direct_tcpip_request(
+ my_chanid,
+ (origin_addr, origin_port),
+ (dest_addr, dest_port),
+ )
+ else:
+ reason = self.server_object.check_channel_request(
+ kind, my_chanid
+ )
+ if reason != OPEN_SUCCEEDED:
+ self._log(
+ DEBUG,
+ 'Rejecting "{}" channel request from client.'.format(kind),
+ )
+ reject = True
+ if reject:
+ msg = Message()
+ msg.add_byte(cMSG_CHANNEL_OPEN_FAILURE)
+ msg.add_int(chanid)
+ msg.add_int(reason)
+ msg.add_string("")
+ msg.add_string("en")
+ self._send_message(msg)
+ return
+
+ chan = Channel(my_chanid)
+ self.lock.acquire()
+ try:
+ self._channels.put(my_chanid, chan)
+ self.channels_seen[my_chanid] = True
+ chan._set_transport(self)
+ chan._set_window(
+ self.default_window_size, self.default_max_packet_size
+ )
+ chan._set_remote_channel(
+ chanid, initial_window_size, max_packet_size
+ )
+ finally:
+ self.lock.release()
+ m = Message()
+ m.add_byte(cMSG_CHANNEL_OPEN_SUCCESS)
+ m.add_int(chanid)
+ m.add_int(my_chanid)
+ m.add_int(self.default_window_size)
+ m.add_int(self.default_max_packet_size)
+ self._send_message(m)
+ self._log(
+ DEBUG, "Secsh channel {:d} ({}) opened.".format(my_chanid, kind)
+ )
+ if kind == "auth-agent@openssh.com":
+ self._forward_agent_handler(chan)
+ elif kind == "x11":
+ self._x11_handler(chan, (origin_addr, origin_port))
+ elif kind == "forwarded-tcpip":
+ chan.origin_addr = (origin_addr, origin_port)
+ self._tcp_handler(
+ chan, (origin_addr, origin_port), (server_addr, server_port)
+ )
+ else:
+ self._queue_incoming_channel(chan)
+
+ def _parse_debug(self, m):
+ m.get_boolean() # always_display
+ msg = m.get_string()
+ m.get_string() # language
+ self._log(DEBUG, "Debug msg: {}".format(util.safe_string(msg)))
+
+ def _get_subsystem_handler(self, name):
+ try:
+ self.lock.acquire()
+ if name not in self.subsystem_table:
+ return None, [], {}
+ return self.subsystem_table[name]
+ finally:
+ self.lock.release()
+
+ _channel_handler_table = {
+ MSG_CHANNEL_SUCCESS: Channel._request_success,
+ MSG_CHANNEL_FAILURE: Channel._request_failed,
+ MSG_CHANNEL_DATA: Channel._feed,
+ MSG_CHANNEL_EXTENDED_DATA: Channel._feed_extended,
MSG_CHANNEL_WINDOW_ADJUST: Channel._window_adjust,
- MSG_CHANNEL_REQUEST: Channel._handle_request, MSG_CHANNEL_EOF:
- Channel._handle_eof, MSG_CHANNEL_CLOSE: Channel._handle_close}
+ MSG_CHANNEL_REQUEST: Channel._handle_request,
+ MSG_CHANNEL_EOF: Channel._handle_eof,
+ MSG_CHANNEL_CLOSE: Channel._handle_close,
+ }
+# TODO 4.0: drop this, we barely use it ourselves, it badly replicates the
+# Transport-internal algorithm management, AND does so in a way which doesn't
+# honor newer things like disabled_algorithms!
class SecurityOptions:
"""
Simple object containing the security preferences of an ssh transport.
@@ -1277,7 +3131,8 @@ class SecurityOptions:
``ValueError`` will be raised. If you try to assign something besides a
tuple to one of the fields, ``TypeError`` will be raised.
"""
- __slots__ = '_transport'
+
+ __slots__ = "_transport"
def __init__(self, transport):
self._transport = transport
@@ -1286,40 +3141,102 @@ class SecurityOptions:
"""
Returns a string representation of this object, for debugging.
"""
- return '<paramiko.SecurityOptions for {!r}>'.format(self._transport)
+ return "<paramiko.SecurityOptions for {!r}>".format(self._transport)
+
+ def _set(self, name, orig, x):
+ if type(x) is list:
+ x = tuple(x)
+ if type(x) is not tuple:
+ raise TypeError("expected tuple or list")
+ possible = list(getattr(self._transport, orig).keys())
+ forbidden = [n for n in x if n not in possible]
+ if len(forbidden) > 0:
+ raise ValueError("unknown cipher")
+ setattr(self._transport, name, x)
@property
def ciphers(self):
"""Symmetric encryption ciphers"""
- pass
+ return self._transport._preferred_ciphers
+
+ @ciphers.setter
+ def ciphers(self, x):
+ self._set("_preferred_ciphers", "_cipher_info", x)
@property
def digests(self):
"""Digest (one-way hash) algorithms"""
- pass
+ return self._transport._preferred_macs
+
+ @digests.setter
+ def digests(self, x):
+ self._set("_preferred_macs", "_mac_info", x)
@property
def key_types(self):
"""Public-key algorithms"""
- pass
+ return self._transport._preferred_keys
+
+ @key_types.setter
+ def key_types(self, x):
+ self._set("_preferred_keys", "_key_info", x)
@property
def kex(self):
"""Key exchange algorithms"""
- pass
+ return self._transport._preferred_kex
+
+ @kex.setter
+ def kex(self, x):
+ self._set("_preferred_kex", "_kex_info", x)
@property
def compression(self):
"""Compression algorithms"""
- pass
+ return self._transport._preferred_compression
+ @compression.setter
+ def compression(self, x):
+ self._set("_preferred_compression", "_compression_info", x)
-class ChannelMap:
+class ChannelMap:
def __init__(self):
+ # (id -> Channel)
self._map = weakref.WeakValueDictionary()
self._lock = threading.Lock()
+ def put(self, chanid, chan):
+ self._lock.acquire()
+ try:
+ self._map[chanid] = chan
+ finally:
+ self._lock.release()
+
+ def get(self, chanid):
+ self._lock.acquire()
+ try:
+ return self._map.get(chanid, None)
+ finally:
+ self._lock.release()
+
+ def delete(self, chanid):
+ self._lock.acquire()
+ try:
+ try:
+ del self._map[chanid]
+ except KeyError:
+ pass
+ finally:
+ self._lock.release()
+
+ def values(self):
+ self._lock.acquire()
+ try:
+ return list(self._map.values())
+ finally:
+ self._lock.release()
+
def __len__(self):
self._lock.acquire()
try:
@@ -1335,7 +3252,152 @@ class ServiceRequestingTransport(Transport):
.. versionadded:: 3.2
"""
+ # NOTE: this purposefully duplicates some of the parent class in order to
+ # modernize, refactor, etc. The intent is that eventually we will collapse
+ # this one onto the parent in a backwards incompatible release.
+
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._service_userauth_accepted = False
self._handler_table[MSG_SERVICE_ACCEPT] = self._parse_service_accept
+
+ def _parse_service_accept(self, m):
+ service = m.get_text()
+ # Short-circuit for any service name not ssh-userauth.
+ # NOTE: it's technically possible for 'service name' in
+ # SERVICE_REQUEST/ACCEPT messages to be "ssh-connection" --
+ # but I don't see evidence of Paramiko ever initiating or expecting to
+ # receive one of these. We /do/ see the 'service name' field in
+ # MSG_USERAUTH_REQUEST/ACCEPT/FAILURE set to this string, but that is a
+ # different set of handlers, so...!
+ if service != "ssh-userauth":
+ # TODO 4.0: consider erroring here (with an ability to opt out?)
+ # instead as it probably means something went Very Wrong.
+ self._log(
+ DEBUG, 'Service request "{}" accepted (?)'.format(service)
+ )
+ return
+ # Record that we saw a service-userauth acceptance, meaning we are free
+ # to submit auth requests.
+ self._service_userauth_accepted = True
+ self._log(DEBUG, "MSG_SERVICE_ACCEPT received; auth may begin")
+
+ def ensure_session(self):
+ # Make sure we're not trying to auth on a not-yet-open or
+ # already-closed transport session; that's our responsibility, not that
+ # of AuthHandler.
+ if (not self.active) or (not self.initial_kex_done):
+ # TODO: better error message? this can happen in many places, eg
+ # user error (authing before connecting) or developer error (some
+ # improperly handled pre/mid auth shutdown didn't become fatal
+ # enough). The latter is much more common & should ideally be fixed
+ # by terminating things harder?
+ raise SSHException("No existing session")
+ # Also make sure we've actually been told we are allowed to auth.
+ if self._service_userauth_accepted:
+ return
+ # Or request to do so, otherwise.
+ m = Message()
+ m.add_byte(cMSG_SERVICE_REQUEST)
+ m.add_string("ssh-userauth")
+ self._log(DEBUG, "Sending MSG_SERVICE_REQUEST: ssh-userauth")
+ self._send_message(m)
+ # Now we wait to hear back; the user is expecting a blocking-style auth
+ # request so there's no point giving control back anywhere.
+ while not self._service_userauth_accepted:
+ # TODO: feels like we're missing an AuthHandler Event like
+ # 'self.auth_event' which is set when AuthHandler shuts down in
+ # ways good AND bad. Transport only seems to have completion_event
+ # which is unclear re: intent, eg it's set by newkeys which always
+ # happens on connection, so it'll always be set by the time we get
+ # here.
+ # NOTE: this copies the timing of event.wait() in
+ # AuthHandler.wait_for_response, re: 1/10 of a second. Could
+ # presumably be smaller, but seems unlikely this period is going to
+ # be "too long" for any code doing ssh networking...
+ time.sleep(0.1)
+ self.auth_handler = self.get_auth_handler()
+
+ def get_auth_handler(self):
+ # NOTE: using new sibling subclass instead of classic AuthHandler
+ return AuthOnlyHandler(self)
+
+ def auth_none(self, username):
+ # TODO 4.0: merge to parent, preserving (most of) docstring
+ self.ensure_session()
+ return self.auth_handler.auth_none(username)
+
+ def auth_password(self, username, password, fallback=True):
+ # TODO 4.0: merge to parent, preserving (most of) docstring
+ self.ensure_session()
+ try:
+ return self.auth_handler.auth_password(username, password)
+ except BadAuthenticationType as e:
+ # if password auth isn't allowed, but keyboard-interactive *is*,
+ # try to fudge it
+ if not fallback or ("keyboard-interactive" not in e.allowed_types):
+ raise
+ try:
+
+ def handler(title, instructions, fields):
+ if len(fields) > 1:
+ raise SSHException("Fallback authentication failed.")
+ if len(fields) == 0:
+ # for some reason, at least on os x, a 2nd request will
+ # be made with zero fields requested. maybe it's just
+ # to try to fake out automated scripting of the exact
+ # type we're doing here. *shrug* :)
+ return []
+ return [password]
+
+ return self.auth_interactive(username, handler)
+ except SSHException:
+ # attempt to fudge failed; just raise the original exception
+ raise e
+
+ def auth_publickey(self, username, key):
+ # TODO 4.0: merge to parent, preserving (most of) docstring
+ self.ensure_session()
+ return self.auth_handler.auth_publickey(username, key)
+
+ def auth_interactive(self, username, handler, submethods=""):
+ # TODO 4.0: merge to parent, preserving (most of) docstring
+ self.ensure_session()
+ return self.auth_handler.auth_interactive(
+ username, handler, submethods
+ )
+
+ def auth_interactive_dumb(self, username, handler=None, submethods=""):
+ # TODO 4.0: merge to parent, preserving (most of) docstring
+ # NOTE: legacy impl omitted equiv of ensure_session since it just wraps
+ # another call to an auth method. however we reinstate it for
+ # consistency reasons.
+ self.ensure_session()
+ if not handler:
+
+ def handler(title, instructions, prompt_list):
+ answers = []
+ if title:
+ print(title.strip())
+ if instructions:
+ print(instructions.strip())
+ for prompt, show_input in prompt_list:
+ print(prompt.strip(), end=" ")
+ answers.append(input())
+ return answers
+
+ return self.auth_interactive(username, handler, submethods)
+
+ def auth_gssapi_with_mic(self, username, gss_host, gss_deleg_creds):
+ # TODO 4.0: merge to parent, preserving (most of) docstring
+ self.ensure_session()
+ self.auth_handler = self.get_auth_handler()
+ return self.auth_handler.auth_gssapi_with_mic(
+ username, gss_host, gss_deleg_creds
+ )
+
+ def auth_gssapi_keyex(self, username):
+ # TODO 4.0: merge to parent, preserving (most of) docstring
+ self.ensure_session()
+ self.auth_handler = self.get_auth_handler()
+ return self.auth_handler.auth_gssapi_keyex(username)
diff --git a/paramiko/util.py b/paramiko/util.py
index d9df7198..f1e33a50 100644
--- a/paramiko/util.py
+++ b/paramiko/util.py
@@ -1,25 +1,142 @@
+# Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Useful functions used by the rest of paramiko.
"""
+
+
import sys
import struct
import traceback
import threading
import logging
-from paramiko.common import DEBUG, zero_byte, xffffffff, max_byte, byte_ord, byte_chr
+
+from paramiko.common import (
+ DEBUG,
+ zero_byte,
+ xffffffff,
+ max_byte,
+ byte_ord,
+ byte_chr,
+)
from paramiko.config import SSHConfig
def inflate_long(s, always_positive=False):
"""turns a normalized byte string into a long-int
(adapted from Crypto.Util.number)"""
- pass
+ out = 0
+ negative = 0
+ if not always_positive and (len(s) > 0) and (byte_ord(s[0]) >= 0x80):
+ negative = 1
+ if len(s) % 4:
+ filler = zero_byte
+ if negative:
+ filler = max_byte
+ # never convert this to ``s +=`` because this is a string, not a number
+ # noinspection PyAugmentAssignment
+ s = filler * (4 - len(s) % 4) + s
+ for i in range(0, len(s), 4):
+ out = (out << 32) + struct.unpack(">I", s[i : i + 4])[0]
+ if negative:
+ out -= 1 << (8 * len(s))
+ return out
def deflate_long(n, add_sign_padding=True):
"""turns a long-int into a normalized byte string
(adapted from Crypto.Util.number)"""
- pass
+ # after much testing, this algorithm was deemed to be the fastest
+ s = bytes()
+ n = int(n)
+ while (n != 0) and (n != -1):
+ s = struct.pack(">I", n & xffffffff) + s
+ n >>= 32
+ # strip off leading zeros, FFs
+ for i in enumerate(s):
+ if (n == 0) and (i[1] != 0):
+ break
+ if (n == -1) and (i[1] != 0xFF):
+ break
+ else:
+ # degenerate case, n was either 0 or -1
+ i = (0,)
+ if n == 0:
+ s = zero_byte
+ else:
+ s = max_byte
+ s = s[i[0] :]
+ if add_sign_padding:
+ if (n == 0) and (byte_ord(s[0]) >= 0x80):
+ s = zero_byte + s
+ if (n == -1) and (byte_ord(s[0]) < 0x80):
+ s = max_byte + s
+ return s
+
+
+def format_binary(data, prefix=""):
+ x = 0
+ out = []
+ while len(data) > x + 16:
+ out.append(format_binary_line(data[x : x + 16]))
+ x += 16
+ if x < len(data):
+ out.append(format_binary_line(data[x:]))
+ return [prefix + line for line in out]
+
+
+def format_binary_line(data):
+ left = " ".join(["{:02X}".format(byte_ord(c)) for c in data])
+ right = "".join(
+ [".{:c}..".format(byte_ord(c))[(byte_ord(c) + 63) // 95] for c in data]
+ )
+ return "{:50s} {}".format(left, right)
+
+
+def safe_string(s):
+ out = b""
+ for c in s:
+ i = byte_ord(c)
+ if 32 <= i <= 127:
+ out += byte_chr(i)
+ else:
+ out += b("%{:02X}".format(i))
+ return out
+
+
+def bit_length(n):
+ try:
+ return n.bit_length()
+ except AttributeError:
+ norm = deflate_long(n, False)
+ hbyte = byte_ord(norm[0])
+ if hbyte == 0:
+ return 1
+ bitlen = len(norm) * 8
+ while not (hbyte & 0x80):
+ hbyte <<= 1
+ bitlen -= 1
+ return bitlen
+
+
+def tb_strings():
+ return "".join(traceback.format_exception(*sys.exc_info())).split("\n")
def generate_key_bytes(hash_alg, salt, key, nbytes):
@@ -36,7 +153,21 @@ def generate_key_bytes(hash_alg, salt, key, nbytes):
:param int nbytes: number of bytes to generate.
:return: Key data, as `bytes`.
"""
- pass
+ keydata = bytes()
+ digest = bytes()
+ if len(salt) > 8:
+ salt = salt[:8]
+ while nbytes > 0:
+ hash_obj = hash_alg()
+ if len(digest) > 0:
+ hash_obj.update(digest)
+ hash_obj.update(b(key))
+ hash_obj.update(salt)
+ digest = hash_obj.digest()
+ size = min(nbytes, len(digest))
+ keydata += digest[:size]
+ nbytes -= size
+ return keydata
def load_host_keys(filename):
@@ -55,7 +186,9 @@ def load_host_keys(filename):
:return:
nested dict of `.PKey` objects, indexed by hostname and then keytype
"""
- pass
+ from paramiko.hostkeys import HostKeys
+
+ return HostKeys(filename)
def parse_ssh_config(file_obj):
@@ -65,14 +198,31 @@ def parse_ssh_config(file_obj):
.. deprecated:: 2.7
Use `SSHConfig.from_file` instead.
"""
- pass
+ config = SSHConfig()
+ config.parse(file_obj)
+ return config
def lookup_ssh_host_config(hostname, config):
"""
Provided only as a backward-compatible wrapper around `.SSHConfig`.
"""
- pass
+ return config.lookup(hostname)
+
+
+def mod_inverse(x, m):
+ # it's crazy how small Python can make this function.
+ u1, u2, u3 = 1, 0, m
+ v1, v2, v3 = 0, 1, x
+
+ while v3 > 0:
+ q = u3 // v3
+ u1, v1 = v1, u1 - v1 * q
+ u2, v2 = v2, u2 - v2 * q
+ u3, v3 = v3, u3 - v3 * q
+ if u2 < 0:
+ u2 += m
+ return u2
_g_thread_data = threading.local()
@@ -80,21 +230,59 @@ _g_thread_counter = 0
_g_thread_lock = threading.Lock()
+def get_thread_id():
+ global _g_thread_data, _g_thread_counter, _g_thread_lock
+ try:
+ return _g_thread_data.id
+ except AttributeError:
+ with _g_thread_lock:
+ _g_thread_counter += 1
+ _g_thread_data.id = _g_thread_counter
+ return _g_thread_data.id
+
+
def log_to_file(filename, level=DEBUG):
"""send paramiko logs to a logfile,
if they're not already going somewhere"""
- pass
-
-
+ logger = logging.getLogger("paramiko")
+ if len(logger.handlers) > 0:
+ return
+ logger.setLevel(level)
+ f = open(filename, "a")
+ handler = logging.StreamHandler(f)
+ frm = "%(levelname)-.3s [%(asctime)s.%(msecs)03d] thr=%(_threadid)-3d"
+ frm += " %(name)s: %(message)s"
+ handler.setFormatter(logging.Formatter(frm, "%Y%m%d-%H:%M:%S"))
+ logger.addHandler(handler)
+
+
+# make only one filter object, so it doesn't get applied more than once
class PFilter:
- pass
+ def filter(self, record):
+ record._threadid = get_thread_id()
+ return True
_pfilter = PFilter()
-class ClosingContextManager:
+def get_logger(name):
+ logger = logging.getLogger(name)
+ logger.addFilter(_pfilter)
+ return logger
+
+
+def constant_time_bytes_eq(a, b):
+ if len(a) != len(b):
+ return False
+ res = 0
+ # noinspection PyUnresolvedReferences
+ for i in range(len(a)): # noqa: F821
+ res |= byte_ord(a[i]) ^ byte_ord(b[i])
+ return res == 0
+
+class ClosingContextManager:
def __enter__(self):
return self
@@ -102,18 +290,48 @@ class ClosingContextManager:
self.close()
+def clamp_value(minimum, val, maximum):
+ return max(minimum, min(val, maximum))
+
+
def asbytes(s):
"""
Coerce to bytes if possible or return unchanged.
"""
- pass
-
-
-def b(s, encoding='utf8'):
+ try:
+ # Attempt to run through our version of b(), which does the Right Thing
+ # for unicode strings vs bytestrings, and raises TypeError if it's not
+ # one of those types.
+ return b(s)
+ except TypeError:
+ try:
+ # If it wasn't a string/byte/buffer-ish object, try calling an
+ # asbytes() method, which many of our internal classes implement.
+ return s.asbytes()
+ except AttributeError:
+ # Finally, just do nothing & assume this object is sufficiently
+ # byte-y or buffer-y that everything will work out (or that callers
+ # are capable of handling whatever it is.)
+ return s
+
+
+# TODO: clean this up / force callers to assume bytes OR unicode
+def b(s, encoding="utf8"):
"""cast unicode or bytes to bytes"""
- pass
+ if isinstance(s, bytes):
+ return s
+ elif isinstance(s, str):
+ return s.encode(encoding)
+ else:
+ raise TypeError(f"Expected unicode or bytes, got {type(s)}")
-def u(s, encoding='utf8'):
+# TODO: clean this up / force callers to assume bytes OR unicode
+def u(s, encoding="utf8"):
"""cast bytes or unicode to unicode"""
- pass
+ if isinstance(s, bytes):
+ return s.decode(encoding)
+ elif isinstance(s, str):
+ return s
+ else:
+ raise TypeError(f"Expected unicode or bytes, got {type(s)}")
diff --git a/paramiko/win_openssh.py b/paramiko/win_openssh.py
index ac5aeeee..614b5898 100644
--- a/paramiko/win_openssh.py
+++ b/paramiko/win_openssh.py
@@ -1,17 +1,56 @@
+# Copyright (C) 2021 Lew Gordon <lew.gordon@genesys.com>
+# Copyright (C) 2022 Patrick Spendrin <ps_ml@gmx.de>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
import os.path
import time
-PIPE_NAME = '\\\\.\\pipe\\openssh-ssh-agent'
+PIPE_NAME = r"\\.\pipe\openssh-ssh-agent"
+
+
+def can_talk_to_agent():
+ # use os.listdir() instead of os.path.exists(), because os.path.exists()
+ # uses CreateFileW() API and the pipe cannot be reopened unless the server
+ # calls DisconnectNamedPipe().
+ dir_, name = os.path.split(PIPE_NAME)
+ name = name.lower()
+ return any(name == n.lower() for n in os.listdir(dir_))
-class OpenSSHAgentConnection:
+class OpenSSHAgentConnection:
def __init__(self):
while True:
try:
self._pipe = os.open(PIPE_NAME, os.O_RDWR | os.O_BINARY)
except OSError as e:
+ # retry on errno 22, which means that the server has not
+ # called DisconnectNamedPipe() yet.
if e.errno != 22:
raise
else:
break
time.sleep(0.1)
+
+ def send(self, data):
+ return os.write(self._pipe, data)
+
+ def recv(self, n):
+ return os.read(self._pipe, n)
+
+ def close(self):
+ return os.close(self._pipe)
diff --git a/paramiko/win_pageant.py b/paramiko/win_pageant.py
index 2bad5392..c927de65 100644
--- a/paramiko/win_pageant.py
+++ b/paramiko/win_pageant.py
@@ -1,19 +1,49 @@
+# Copyright (C) 2005 John Arbash-Meinel <john@arbash-meinel.com>
+# Modified up by: Todd Whiteman <ToddW@ActiveState.com>
+#
+# This file is part of paramiko.
+#
+# Paramiko is free software; you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation; either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
"""
Functions for communicating with Pageant, the basic windows ssh agent program.
"""
+
import array
import ctypes.wintypes
import platform
import struct
from paramiko.common import zero_byte
from paramiko.util import b
+
import _thread as thread
+
from . import _winapi
-_AGENT_COPYDATA_ID = 2152616122
+
+
+_AGENT_COPYDATA_ID = 0x804E50BA
_AGENT_MAX_MSGLEN = 8192
+# Note: The WM_COPYDATA value is pulled from win32con, as a workaround
+# so we do not have to import this huge library just for this one variable.
win32con_WM_COPYDATA = 74
+def _get_pageant_window_object():
+ return ctypes.windll.user32.FindWindowA(b"Pageant", b"Pageant")
+
+
def can_talk_to_agent():
"""
Check to see if there is a "Pageant" agent we can talk to.
@@ -21,10 +51,10 @@ def can_talk_to_agent():
This checks both if we have the required libraries (win32all or ctypes)
and if there is a Pageant currently running.
"""
- pass
+ return bool(_get_pageant_window_object())
-if platform.architecture()[0] == '64bit':
+if platform.architecture()[0] == "64bit":
ULONG_PTR = ctypes.c_uint64
else:
ULONG_PTR = ctypes.c_uint32
@@ -35,8 +65,12 @@ class COPYDATASTRUCT(ctypes.Structure):
ctypes implementation of
http://msdn.microsoft.com/en-us/library/windows/desktop/ms649010%28v=vs.85%29.aspx
"""
- _fields_ = [('num_data', ULONG_PTR), ('data_size', ctypes.wintypes.
- DWORD), ('data_loc', ctypes.c_void_p)]
+
+ _fields_ = [
+ ("num_data", ULONG_PTR),
+ ("data_size", ctypes.wintypes.DWORD),
+ ("data_loc", ctypes.c_void_p),
+ ]
def _query_pageant(msg):
@@ -44,7 +78,37 @@ def _query_pageant(msg):
Communication with the Pageant process is done through a shared
memory-mapped file.
"""
- pass
+ hwnd = _get_pageant_window_object()
+ if not hwnd:
+ # Pageant isn't running anymore: report the failure to connect as None.
+ return None
+
+ # create a name for the mmap
+ map_name = f"PageantRequest{thread.get_ident():08x}"
+
+ pymap = _winapi.MemoryMap(
+ map_name, _AGENT_MAX_MSGLEN, _winapi.get_security_attributes_for_user()
+ )
+ with pymap:
+ pymap.write(msg)
+ # Create an array buffer containing the mapped filename
+ char_buffer = array.array("b", b(map_name) + zero_byte) # noqa
+ char_buffer_address, char_buffer_size = char_buffer.buffer_info()
+ # Create the COPYDATASTRUCT to pass to the SendMessage function call
+ cds = COPYDATASTRUCT(
+ _AGENT_COPYDATA_ID, char_buffer_size, char_buffer_address
+ )
+
+ response = ctypes.windll.user32.SendMessageA(
+ hwnd, win32con_WM_COPYDATA, ctypes.sizeof(cds), ctypes.byref(cds)
+ )
+
+ if response > 0:
+ pymap.seek(0)
+ datalen = pymap.read(4)
+ retlen = struct.unpack(">I", datalen)[0]
+ return datalen + pymap.read(retlen)
+ return None
class PageantConnection:
@@ -57,3 +121,18 @@ class PageantConnection:
def __init__(self):
self._response = None
+
+ def send(self, data):
+ self._response = _query_pageant(data)
+
+ def recv(self, n):
+ if self._response is None:
+ return ""
+ ret = self._response[:n]
+ self._response = self._response[n:]
+ if self._response == "":
+ self._response = None
+ return ret
+
+ def close(self):
+ pass