Back to the Reference (Gold) summary
Reference (Gold): cookiecutter
Pytest summary for test target `tests`
status | count |
---|---|
passed | 367 |
skipped | 4 |
total | 371 |
collected | 371 |
Failed pytests: none (367 passed + 4 skipped = all 371 collected tests).
Patch diff
diff --git a/cookiecutter/cli.py b/cookiecutter/cli.py
index b050655..8b67863 100644
--- a/cookiecutter/cli.py
+++ b/cookiecutter/cli.py
@@ -1,81 +1,241 @@
"""Main `cookiecutter` CLI."""
+
import collections
import json
import os
import sys
+
import click
+
from cookiecutter import __version__
from cookiecutter.config import get_user_config
-from cookiecutter.exceptions import ContextDecodingException, FailedHookException, InvalidModeException, InvalidZipRepository, OutputDirExistsException, RepositoryCloneFailed, RepositoryNotFound, UndefinedVariableInTemplate, UnknownExtension
+from cookiecutter.exceptions import (
+ ContextDecodingException,
+ FailedHookException,
+ InvalidModeException,
+ InvalidZipRepository,
+ OutputDirExistsException,
+ RepositoryCloneFailed,
+ RepositoryNotFound,
+ UndefinedVariableInTemplate,
+ UnknownExtension,
+)
from cookiecutter.log import configure_logger
from cookiecutter.main import cookiecutter
def version_msg():
"""Return the Cookiecutter version, location and Python powering it."""
- pass
+ python_version = sys.version
+ location = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ return f"Cookiecutter {__version__} from {location} (Python {python_version})"
def validate_extra_context(ctx, param, value):
"""Validate extra context."""
- pass
+ for string in value:
+ if '=' not in string:
+ raise click.BadParameter(
+ f"EXTRA_CONTEXT should contain items of the form key=value; "
+ f"'{string}' doesn't match that form"
+ )
+
+ # Convert tuple -- e.g.: ('program_name=foobar', 'startsecs=66')
+ # to dict -- e.g.: {'program_name': 'foobar', 'startsecs': '66'}
+ return collections.OrderedDict(s.split('=', 1) for s in value) or None
def list_installed_templates(default_config, passed_config_file):
"""List installed (locally cloned) templates. Use cookiecutter --list-installed."""
- pass
+ config = get_user_config(passed_config_file, default_config)
+ cookiecutter_folder = config.get('cookiecutters_dir')
+ if not os.path.exists(cookiecutter_folder):
+ click.echo(
+ f"Error: Cannot list installed templates. "
+ f"Folder does not exist: {cookiecutter_folder}"
+ )
+ sys.exit(-1)
+
+ template_names = [
+ folder
+ for folder in os.listdir(cookiecutter_folder)
+ if os.path.exists(
+ os.path.join(cookiecutter_folder, folder, 'cookiecutter.json')
+ )
+ ]
+ click.echo(f'{len(template_names)} installed templates: ')
+ for name in template_names:
+ click.echo(f' * {name}')
@click.command(context_settings=dict(help_option_names=['-h', '--help']))
@click.version_option(__version__, '-V', '--version', message=version_msg())
@click.argument('template', required=False)
@click.argument('extra_context', nargs=-1, callback=validate_extra_context)
-@click.option('--no-input', is_flag=True, help=
- 'Do not prompt for parameters and only use cookiecutter.json file content. Defaults to deleting any cached resources and redownloading them. Cannot be combined with the --replay flag.'
- )
-@click.option('-c', '--checkout', help=
- 'branch, tag or commit to checkout after git clone')
-@click.option('--directory', help=
- 'Directory within repo that holds cookiecutter.json file for advanced repositories with multi templates in it'
- )
-@click.option('-v', '--verbose', is_flag=True, help=
- 'Print debug information', default=False)
-@click.option('--replay', is_flag=True, help=
- 'Do not prompt for parameters and only use information entered previously. Cannot be combined with the --no-input flag or with extra configuration passed.'
- )
-@click.option('--replay-file', type=click.Path(), default=None, help=
- 'Use this file for replay instead of the default.')
-@click.option('-f', '--overwrite-if-exists', is_flag=True, help=
- 'Overwrite the contents of the output directory if it already exists')
-@click.option('-s', '--skip-if-file-exists', is_flag=True, help=
- 'Skip the files in the corresponding directories if they already exist',
- default=False)
-@click.option('-o', '--output-dir', default='.', type=click.Path(), help=
- 'Where to output the generated project dir into')
-@click.option('--config-file', type=click.Path(), default=None, help=
- 'User configuration file')
-@click.option('--default-config', is_flag=True, help=
- 'Do not load a config file. Use the defaults instead')
-@click.option('--debug-file', type=click.Path(), default=None, help=
- 'File to be used as a stream for DEBUG logging')
-@click.option('--accept-hooks', type=click.Choice(['yes', 'ask', 'no']),
- default='yes', help='Accept pre/post hooks')
-@click.option('-l', '--list-installed', is_flag=True, help=
- 'List currently installed templates.')
-@click.option('--keep-project-on-failure', is_flag=True, help=
- 'Do not delete project folder on failure')
-def main(template, extra_context, no_input, checkout, verbose, replay,
- overwrite_if_exists, output_dir, config_file, default_config,
- debug_file, directory, skip_if_file_exists, accept_hooks, replay_file,
- list_installed, keep_project_on_failure):
+@click.option(
+ '--no-input',
+ is_flag=True,
+ help='Do not prompt for parameters and only use cookiecutter.json file content. '
+ 'Defaults to deleting any cached resources and redownloading them. '
+ 'Cannot be combined with the --replay flag.',
+)
+@click.option(
+ '-c',
+ '--checkout',
+ help='branch, tag or commit to checkout after git clone',
+)
+@click.option(
+ '--directory',
+ help='Directory within repo that holds cookiecutter.json file '
+ 'for advanced repositories with multi templates in it',
+)
+@click.option(
+ '-v', '--verbose', is_flag=True, help='Print debug information', default=False
+)
+@click.option(
+ '--replay',
+ is_flag=True,
+ help='Do not prompt for parameters and only use information entered previously. '
+ 'Cannot be combined with the --no-input flag or with extra configuration passed.',
+)
+@click.option(
+ '--replay-file',
+ type=click.Path(),
+ default=None,
+ help='Use this file for replay instead of the default.',
+)
+@click.option(
+ '-f',
+ '--overwrite-if-exists',
+ is_flag=True,
+ help='Overwrite the contents of the output directory if it already exists',
+)
+@click.option(
+ '-s',
+ '--skip-if-file-exists',
+ is_flag=True,
+ help='Skip the files in the corresponding directories if they already exist',
+ default=False,
+)
+@click.option(
+ '-o',
+ '--output-dir',
+ default='.',
+ type=click.Path(),
+ help='Where to output the generated project dir into',
+)
+@click.option(
+ '--config-file', type=click.Path(), default=None, help='User configuration file'
+)
+@click.option(
+ '--default-config',
+ is_flag=True,
+ help='Do not load a config file. Use the defaults instead',
+)
+@click.option(
+ '--debug-file',
+ type=click.Path(),
+ default=None,
+ help='File to be used as a stream for DEBUG logging',
+)
+@click.option(
+ '--accept-hooks',
+ type=click.Choice(['yes', 'ask', 'no']),
+ default='yes',
+ help='Accept pre/post hooks',
+)
+@click.option(
+ '-l', '--list-installed', is_flag=True, help='List currently installed templates.'
+)
+@click.option(
+ '--keep-project-on-failure',
+ is_flag=True,
+ help='Do not delete project folder on failure',
+)
+def main(
+ template,
+ extra_context,
+ no_input,
+ checkout,
+ verbose,
+ replay,
+ overwrite_if_exists,
+ output_dir,
+ config_file,
+ default_config,
+ debug_file,
+ directory,
+ skip_if_file_exists,
+ accept_hooks,
+ replay_file,
+ list_installed,
+ keep_project_on_failure,
+):
"""Create a project from a Cookiecutter project template (TEMPLATE).
Cookiecutter is free and open source software, developed and managed by
volunteers. If you would like to help out or fund the project, please get
in touch at https://github.com/cookiecutter/cookiecutter.
"""
- pass
+ # Commands that should work without arguments
+ if list_installed:
+ list_installed_templates(default_config, config_file)
+ sys.exit(0)
+
+ # Raising usage, after all commands that should work without args.
+ if not template or template.lower() == 'help':
+ click.echo(click.get_current_context().get_help())
+ sys.exit(0)
+
+ configure_logger(stream_level='DEBUG' if verbose else 'INFO', debug_file=debug_file)
+
+ # If needed, prompt the user to ask whether or not they want to execute
+ # the pre/post hooks.
+ if accept_hooks == "ask":
+ _accept_hooks = click.confirm("Do you want to execute hooks?")
+ else:
+ _accept_hooks = accept_hooks == "yes"
+
+ if replay_file:
+ replay = replay_file
+
+ try:
+ cookiecutter(
+ template,
+ checkout,
+ no_input,
+ extra_context=extra_context,
+ replay=replay,
+ overwrite_if_exists=overwrite_if_exists,
+ output_dir=output_dir,
+ config_file=config_file,
+ default_config=default_config,
+ password=os.environ.get('COOKIECUTTER_REPO_PASSWORD'),
+ directory=directory,
+ skip_if_file_exists=skip_if_file_exists,
+ accept_hooks=_accept_hooks,
+ keep_project_on_failure=keep_project_on_failure,
+ )
+ except (
+ ContextDecodingException,
+ OutputDirExistsException,
+ InvalidModeException,
+ FailedHookException,
+ UnknownExtension,
+ InvalidZipRepository,
+ RepositoryNotFound,
+ RepositoryCloneFailed,
+ ) as e:
+ click.echo(e)
+ sys.exit(1)
+ except UndefinedVariableInTemplate as undefined_err:
+ click.echo(f'{undefined_err.message}')
+ click.echo(f'Error message: {undefined_err.error.message}')
+
+ context_str = json.dumps(undefined_err.context, indent=4, sort_keys=True)
+ click.echo(f'Context: {context_str}')
+ sys.exit(1)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/cookiecutter/config.py b/cookiecutter/config.py
index 6356215..04d59b7 100644
--- a/cookiecutter/config.py
+++ b/cookiecutter/config.py
@@ -1,23 +1,37 @@
"""Global configuration handling."""
+
import collections
import copy
import logging
import os
+
import yaml
+
from cookiecutter.exceptions import ConfigDoesNotExistException, InvalidConfiguration
+
logger = logging.getLogger(__name__)
+
USER_CONFIG_PATH = os.path.expanduser('~/.cookiecutterrc')
-BUILTIN_ABBREVIATIONS = {'gh': 'https://github.com/{0}.git', 'gl':
- 'https://gitlab.com/{0}.git', 'bb': 'https://bitbucket.org/{0}'}
-DEFAULT_CONFIG = {'cookiecutters_dir': os.path.expanduser(
- '~/.cookiecutters/'), 'replay_dir': os.path.expanduser(
- '~/.cookiecutter_replay/'), 'default_context': collections.OrderedDict(
- []), 'abbreviations': BUILTIN_ABBREVIATIONS}
+
+BUILTIN_ABBREVIATIONS = {
+ 'gh': 'https://github.com/{0}.git',
+ 'gl': 'https://gitlab.com/{0}.git',
+ 'bb': 'https://bitbucket.org/{0}',
+}
+
+DEFAULT_CONFIG = {
+ 'cookiecutters_dir': os.path.expanduser('~/.cookiecutters/'),
+ 'replay_dir': os.path.expanduser('~/.cookiecutter_replay/'),
+ 'default_context': collections.OrderedDict([]),
+ 'abbreviations': BUILTIN_ABBREVIATIONS,
+}
def _expand_path(path):
"""Expand both environment variables and user home in the given path."""
- pass
+ path = os.path.expandvars(path)
+ path = os.path.expanduser(path)
+ return path
def merge_configs(default, overwrite):
@@ -26,12 +40,46 @@ def merge_configs(default, overwrite):
Dict values that are dictionaries themselves will be updated, whilst
preserving existing keys.
"""
- pass
+ new_config = copy.deepcopy(default)
+
+ for k, v in overwrite.items():
+ # Make sure to preserve existing items in
+ # nested dicts, for example `abbreviations`
+ if isinstance(v, dict):
+ new_config[k] = merge_configs(default.get(k, {}), v)
+ else:
+ new_config[k] = v
+
+ return new_config
def get_config(config_path):
"""Retrieve the config from the specified path, returning a config dict."""
- pass
+ if not os.path.exists(config_path):
+ raise ConfigDoesNotExistException(f'Config file {config_path} does not exist.')
+
+ logger.debug('config_path is %s', config_path)
+ with open(config_path, encoding='utf-8') as file_handle:
+ try:
+ yaml_dict = yaml.safe_load(file_handle) or {}
+ except yaml.YAMLError as e:
+ raise InvalidConfiguration(
+ f'Unable to parse YAML file {config_path}.'
+ ) from e
+ if not isinstance(yaml_dict, dict):
+ raise InvalidConfiguration(
+ f'Top-level element of YAML file {config_path} should be an object.'
+ )
+
+ config_dict = merge_configs(DEFAULT_CONFIG, yaml_dict)
+
+ raw_replay_dir = config_dict['replay_dir']
+ config_dict['replay_dir'] = _expand_path(raw_replay_dir)
+
+ raw_cookies_dir = config_dict['cookiecutters_dir']
+ config_dict['cookiecutters_dir'] = _expand_path(raw_cookies_dir)
+
+ return config_dict
def get_user_config(config_file=None, default_config=False):
@@ -53,4 +101,34 @@ def get_user_config(config_file=None, default_config=False):
If the environment variable is not set, try the default config file path
before falling back to the default config values.
"""
- pass
+ # Do NOT load a config. Merge provided values with defaults and return them instead
+ if default_config and isinstance(default_config, dict):
+ return merge_configs(DEFAULT_CONFIG, default_config)
+
+ # Do NOT load a config. Return defaults instead.
+ if default_config:
+ logger.debug("Force ignoring user config with default_config switch.")
+ return copy.copy(DEFAULT_CONFIG)
+
+ # Load the given config file
+ if config_file and config_file is not USER_CONFIG_PATH:
+ logger.debug("Loading custom config from %s.", config_file)
+ return get_config(config_file)
+
+ try:
+ # Does the user set up a config environment variable?
+ env_config_file = os.environ['COOKIECUTTER_CONFIG']
+ except KeyError:
+ # Load an optional user config if it exists
+ # otherwise return the defaults
+ if os.path.exists(USER_CONFIG_PATH):
+ logger.debug("Loading config from %s.", USER_CONFIG_PATH)
+ return get_config(USER_CONFIG_PATH)
+ else:
+ logger.debug("User config not found. Loading default config.")
+ return copy.copy(DEFAULT_CONFIG)
+ else:
+ # There is a config environment variable. Try to load it.
+ # Do not check for existence, so invalid file paths raise an error.
+ logger.debug("User config not found or not specified. Loading default config.")
+ return get_config(env_config_file)
diff --git a/cookiecutter/environment.py b/cookiecutter/environment.py
index 8a7bb61..235f74b 100644
--- a/cookiecutter/environment.py
+++ b/cookiecutter/environment.py
@@ -1,5 +1,7 @@
"""Jinja2 environment and extensions loading."""
+
from jinja2 import Environment, StrictUndefined
+
from cookiecutter.exceptions import UnknownExtension
@@ -20,12 +22,16 @@ class ExtensionLoaderMixin:
3. Attempts to load the extensions. Provides useful error if fails.
"""
context = kwargs.pop('context', {})
- default_extensions = ['cookiecutter.extensions.JsonifyExtension',
+
+ default_extensions = [
+ 'cookiecutter.extensions.JsonifyExtension',
'cookiecutter.extensions.RandomStringExtension',
'cookiecutter.extensions.SlugifyExtension',
'cookiecutter.extensions.TimeExtension',
- 'cookiecutter.extensions.UUIDExtension']
+ 'cookiecutter.extensions.UUIDExtension',
+ ]
extensions = default_extensions + self._read_extensions(context)
+
try:
super().__init__(extensions=extensions, **kwargs)
except ImportError as err:
@@ -37,7 +43,12 @@ class ExtensionLoaderMixin:
If context does not contain the relevant info, return an empty
list instead.
"""
- pass
+ try:
+ extensions = context['cookiecutter']['_extensions']
+ except KeyError:
+ return []
+ else:
+ return [str(ext) for ext in extensions]
class StrictEnvironment(ExtensionLoaderMixin, Environment):
diff --git a/cookiecutter/exceptions.py b/cookiecutter/exceptions.py
index 8de08a2..622e7c6 100644
--- a/cookiecutter/exceptions.py
+++ b/cookiecutter/exceptions.py
@@ -26,6 +26,8 @@ class UnknownTemplateDirException(CookiecutterException):
template, e.g. more than one dir appears to be a template dir.
"""
+ # unused locally
+
class MissingProjectDir(CookiecutterException):
"""
@@ -35,6 +37,8 @@ class MissingProjectDir(CookiecutterException):
directory inside of a repo.
"""
+ # unused locally
+
class ConfigDoesNotExistException(CookiecutterException):
"""
@@ -120,8 +124,10 @@ class UndefinedVariableInTemplate(CookiecutterException):
def __str__(self):
"""Text representation of UndefinedVariableInTemplate."""
return (
- f'{self.message}. Error message: {self.error.message}. Context: {self.context}'
- )
+ f"{self.message}. "
+ f"Error message: {self.error.message}. "
+ f"Context: {self.context}"
+ )
class UnknownExtension(CookiecutterException):
diff --git a/cookiecutter/extensions.py b/cookiecutter/extensions.py
index 8ce014a..666497c 100644
--- a/cookiecutter/extensions.py
+++ b/cookiecutter/extensions.py
@@ -1,8 +1,10 @@
"""Jinja2 extensions."""
+
import json
import string
import uuid
from secrets import choice
+
import arrow
from jinja2 import nodes
from jinja2.ext import Extension
@@ -18,6 +20,7 @@ class JsonifyExtension(Extension):
def jsonify(obj):
return json.dumps(obj, sort_keys=True, indent=4)
+
environment.filters['jsonify'] = jsonify
@@ -30,10 +33,11 @@ class RandomStringExtension(Extension):
def random_ascii_string(length, punctuation=False):
if punctuation:
- corpus = ''.join((string.ascii_letters, string.punctuation))
+ corpus = "".join((string.ascii_letters, string.punctuation))
else:
corpus = string.ascii_letters
- return ''.join(choice(corpus) for _ in range(length))
+ return "".join(choice(corpus) for _ in range(length))
+
environment.globals.update(random_ascii_string=random_ascii_string)
@@ -47,6 +51,7 @@ class SlugifyExtension(Extension):
def slugify(value, **kwargs):
"""Slugifies the value."""
return pyslugify(value, **kwargs)
+
environment.filters['slugify'] = slugify
@@ -60,18 +65,67 @@ class UUIDExtension(Extension):
def uuid4():
"""Generate UUID4."""
return str(uuid.uuid4())
+
environment.globals.update(uuid4=uuid4)
class TimeExtension(Extension):
"""Jinja2 Extension for dates and times."""
+
tags = {'now'}
def __init__(self, environment):
"""Jinja2 Extension constructor."""
super().__init__(environment)
+
environment.extend(datetime_format='%Y-%m-%d')
+ def _datetime(self, timezone, operator, offset, datetime_format):
+ d = arrow.now(timezone)
+
+ # parse shift params from offset and include operator
+ shift_params = {}
+ for param in offset.split(','):
+ interval, value = param.split('=')
+ shift_params[interval.strip()] = float(operator + value.strip())
+ d = d.shift(**shift_params)
+
+ if datetime_format is None:
+ datetime_format = self.environment.datetime_format
+ return d.strftime(datetime_format)
+
+ def _now(self, timezone, datetime_format):
+ if datetime_format is None:
+ datetime_format = self.environment.datetime_format
+ return arrow.now(timezone).strftime(datetime_format)
+
def parse(self, parser):
"""Parse datetime template and add datetime value."""
- pass
+ lineno = next(parser.stream).lineno
+
+ node = parser.parse_expression()
+
+ if parser.stream.skip_if('comma'):
+ datetime_format = parser.parse_expression()
+ else:
+ datetime_format = nodes.Const(None)
+
+ if isinstance(node, nodes.Add):
+ call_method = self.call_method(
+ '_datetime',
+ [node.left, nodes.Const('+'), node.right, datetime_format],
+ lineno=lineno,
+ )
+ elif isinstance(node, nodes.Sub):
+ call_method = self.call_method(
+ '_datetime',
+ [node.left, nodes.Const('-'), node.right, datetime_format],
+ lineno=lineno,
+ )
+ else:
+ call_method = self.call_method(
+ '_now',
+ [node, datetime_format],
+ lineno=lineno,
+ )
+ return nodes.Output([call_method], lineno=lineno)
diff --git a/cookiecutter/find.py b/cookiecutter/find.py
index 667e50d..486735f 100644
--- a/cookiecutter/find.py
+++ b/cookiecutter/find.py
@@ -1,16 +1,34 @@
"""Functions for finding Cookiecutter templates and other components."""
+
import logging
import os
from pathlib import Path
+
from jinja2 import Environment
+
from cookiecutter.exceptions import NonTemplatedInputDirException
+
logger = logging.getLogger(__name__)
-def find_template(repo_dir: 'os.PathLike[str]', env: Environment) ->Path:
+def find_template(repo_dir: "os.PathLike[str]", env: Environment) -> Path:
"""Determine which child directory of ``repo_dir`` is the project template.
:param repo_dir: Local directory of newly cloned repo.
:return: Relative path to project template.
"""
- pass
+ logger.debug('Searching %s for the project template.', repo_dir)
+
+ for str_path in os.listdir(repo_dir):
+ if (
+ 'cookiecutter' in str_path
+ and env.variable_start_string in str_path
+ and env.variable_end_string in str_path
+ ):
+ project_template = Path(repo_dir, str_path)
+ break
+ else:
+ raise NonTemplatedInputDirException
+
+ logger.debug('The project template appears to be %s', project_template)
+ return project_template
diff --git a/cookiecutter/generate.py b/cookiecutter/generate.py
index 715232e..eb3b200 100644
--- a/cookiecutter/generate.py
+++ b/cookiecutter/generate.py
@@ -1,4 +1,5 @@
"""Functions for generating a project from a project template."""
+
import fnmatch
import json
import logging
@@ -7,13 +8,25 @@ import shutil
import warnings
from collections import OrderedDict
from pathlib import Path
+
from binaryornot.check import is_binary
from jinja2 import Environment, FileSystemLoader
from jinja2.exceptions import TemplateSyntaxError, UndefinedError
-from cookiecutter.exceptions import ContextDecodingException, OutputDirExistsException, UndefinedVariableInTemplate
+
+from cookiecutter.exceptions import (
+ ContextDecodingException,
+ OutputDirExistsException,
+ UndefinedVariableInTemplate,
+)
from cookiecutter.find import find_template
from cookiecutter.hooks import run_hook_from_repo_dir
-from cookiecutter.utils import create_env_with_context, make_sure_path_exists, rmtree, work_in
+from cookiecutter.utils import (
+ create_env_with_context,
+ make_sure_path_exists,
+ rmtree,
+ work_in,
+)
+
logger = logging.getLogger(__name__)
@@ -27,17 +40,70 @@ def is_copy_only_path(path, context):
should be rendered or just copied.
:param context: cookiecutter context.
"""
- pass
+ try:
+ for dont_render in context['cookiecutter']['_copy_without_render']:
+ if fnmatch.fnmatch(path, dont_render):
+ return True
+ except KeyError:
+ return False
+
+ return False
-def apply_overwrites_to_context(context, overwrite_context, *,
- in_dictionary_variable=False):
+def apply_overwrites_to_context(
+ context, overwrite_context, *, in_dictionary_variable=False
+):
"""Modify the given context in place based on the overwrite_context."""
- pass
+ for variable, overwrite in overwrite_context.items():
+ if variable not in context:
+ if not in_dictionary_variable:
+ # We are dealing with a new variable on first level, ignore
+ continue
+ # We are dealing with a new dictionary variable in a deeper level
+ context[variable] = overwrite
+
+ context_value = context[variable]
+ if isinstance(context_value, list):
+ if in_dictionary_variable:
+ context[variable] = overwrite
+ continue
+ if isinstance(overwrite, list):
+ # We are dealing with a multichoice variable
+ # Let's confirm all choices are valid for the given context
+ if set(overwrite).issubset(set(context_value)):
+ context[variable] = overwrite
+ else:
+ raise ValueError(
+ f"{overwrite} provided for multi-choice variable "
+ f"{variable}, but valid choices are {context_value}"
+ )
+ else:
+ # We are dealing with a choice variable
+ if overwrite in context_value:
+ # This overwrite is actually valid for the given context
+ # Let's set it as default (by definition first item in list)
+ # see ``cookiecutter.prompt.prompt_choice_for_config``
+ context_value.remove(overwrite)
+ context_value.insert(0, overwrite)
+ else:
+ raise ValueError(
+ f"{overwrite} provided for choice variable "
+ f"{variable}, but the choices are {context_value}."
+ )
+ elif isinstance(context_value, dict) and isinstance(overwrite, dict):
+ # Partially overwrite some keys in original dict
+ apply_overwrites_to_context(
+ context_value, overwrite, in_dictionary_variable=True
+ )
+ context[variable] = context_value
+ else:
+ # Simply overwrite the value for this variable
+ context[variable] = overwrite
-def generate_context(context_file='cookiecutter.json', default_context=None,
- extra_context=None):
+def generate_context(
+ context_file='cookiecutter.json', default_context=None, extra_context=None
+):
"""Generate the context for a Cookiecutter project template.
Loads the JSON file as a Python object, with key being the JSON filename.
@@ -47,11 +113,42 @@ def generate_context(context_file='cookiecutter.json', default_context=None,
:param default_context: Dictionary containing config to take into account.
:param extra_context: Dictionary containing configuration overrides
"""
- pass
+ context = OrderedDict([])
+ try:
+ with open(context_file, encoding='utf-8') as file_handle:
+ obj = json.load(file_handle, object_pairs_hook=OrderedDict)
+ except ValueError as e:
+ # JSON decoding error. Let's throw a new exception that is more
+ # friendly for the developer or user.
+ full_fpath = os.path.abspath(context_file)
+ json_exc_message = str(e)
+ our_exc_message = (
+ f"JSON decoding error while loading '{full_fpath}'. "
+ f"Decoding error details: '{json_exc_message}'"
+ )
+ raise ContextDecodingException(our_exc_message) from e
-def generate_file(project_dir, infile, context, env, skip_if_file_exists=False
- ):
+ # Add the Python object to the context dictionary
+ file_name = os.path.split(context_file)[1]
+ file_stem = file_name.split('.')[0]
+ context[file_stem] = obj
+
+ # Overwrite context variable defaults with the default context from the
+ # user's global config, if available
+ if default_context:
+ try:
+ apply_overwrites_to_context(obj, default_context)
+ except ValueError as error:
+ warnings.warn(f"Invalid default received: {error}")
+ if extra_context:
+ apply_overwrites_to_context(obj, extra_context)
+
+ logger.debug('Context generated is %s', context)
+ return context
+
+
+def generate_file(project_dir, infile, context, env, skip_if_file_exists=False):
"""Render filename of infile as name of outfile, handle infile correctly.
Dealing with infile appropriately:
@@ -72,18 +169,103 @@ def generate_file(project_dir, infile, context, env, skip_if_file_exists=False
:param context: Dict for populating the cookiecutter's variables.
:param env: Jinja2 template execution environment.
"""
- pass
+ logger.debug('Processing file %s', infile)
+
+ # Render the path to the output file (not including the root project dir)
+ outfile_tmpl = env.from_string(infile)
+
+ outfile = os.path.join(project_dir, outfile_tmpl.render(**context))
+ file_name_is_empty = os.path.isdir(outfile)
+ if file_name_is_empty:
+ logger.debug('The resulting file name is empty: %s', outfile)
+ return
+
+ if skip_if_file_exists and os.path.exists(outfile):
+ logger.debug('The resulting file already exists: %s', outfile)
+ return
+
+ logger.debug('Created file at %s', outfile)
+
+ # Just copy over binary files. Don't render.
+ logger.debug("Check %s to see if it's a binary", infile)
+ if is_binary(infile):
+ logger.debug('Copying binary %s to %s without rendering', infile, outfile)
+ shutil.copyfile(infile, outfile)
+ shutil.copymode(infile, outfile)
+ return
+
+ # Force fwd slashes on Windows for get_template
+ # This is a by-design Jinja issue
+ infile_fwd_slashes = infile.replace(os.path.sep, '/')
+
+ # Render the file
+ try:
+ tmpl = env.get_template(infile_fwd_slashes)
+ except TemplateSyntaxError as exception:
+ # Disable translated so that printed exception contains verbose
+ # information about syntax error location
+ exception.translated = False
+ raise
+ rendered_file = tmpl.render(**context)
+
+ if context['cookiecutter'].get('_new_lines', False):
+ # Use `_new_lines` from context, if configured.
+ newline = context['cookiecutter']['_new_lines']
+ logger.debug('Using configured newline character %s', repr(newline))
+ else:
+ # Detect original file newline to output the rendered file.
+ # Note that newlines can be a tuple if file contains mixed line endings.
+ # In this case, we pick the first line ending we detected.
+ with open(infile, encoding='utf-8') as rd:
+ rd.readline() # Read only the first line to load a 'newlines' value.
+ newline = rd.newlines[0] if isinstance(rd.newlines, tuple) else rd.newlines
+ logger.debug('Using detected newline character %s', repr(newline))
+
+ logger.debug('Writing contents to file %s', outfile)
+
+ with open(outfile, 'w', encoding='utf-8', newline=newline) as fh:
+ fh.write(rendered_file)
+
+ # Apply file permissions to output file
+ shutil.copymode(infile, outfile)
-def render_and_create_dir(dirname: str, context: dict, output_dir:
- 'os.PathLike[str]', environment: Environment, overwrite_if_exists: bool
- =False):
+def render_and_create_dir(
+ dirname: str,
+ context: dict,
+ output_dir: "os.PathLike[str]",
+ environment: Environment,
+ overwrite_if_exists: bool = False,
+):
"""Render name of a directory, create the directory, return its path."""
- pass
+ name_tmpl = environment.from_string(dirname)
+ rendered_dirname = name_tmpl.render(**context)
+ dir_to_create = Path(output_dir, rendered_dirname)
-def _run_hook_from_repo_dir(repo_dir, hook_name, project_dir, context,
- delete_project_on_failure):
+ logger.debug(
+ 'Rendered dir %s must exist in output_dir %s', dir_to_create, output_dir
+ )
+
+ output_dir_exists = dir_to_create.exists()
+
+ if output_dir_exists:
+ if overwrite_if_exists:
+ logger.debug(
+ 'Output directory %s already exists, overwriting it', dir_to_create
+ )
+ else:
+ msg = f'Error: "{dir_to_create}" directory already exists'
+ raise OutputDirExistsException(msg)
+ else:
+ make_sure_path_exists(dir_to_create)
+
+ return dir_to_create, not output_dir_exists
+
+
+def _run_hook_from_repo_dir(
+ repo_dir, hook_name, project_dir, context, delete_project_on_failure
+):
"""Run hook from repo directory, clean project directory if hook fails.
:param repo_dir: Project template input directory.
@@ -93,12 +275,26 @@ def _run_hook_from_repo_dir(repo_dir, hook_name, project_dir, context,
:param delete_project_on_failure: Delete the project directory on hook
failure?
"""
- pass
+ warnings.warn(
+ "The '_run_hook_from_repo_dir' function is deprecated, "
+ "use 'cookiecutter.hooks.run_hook_from_repo_dir' instead",
+ DeprecationWarning,
+ 2,
+ )
+ run_hook_from_repo_dir(
+ repo_dir, hook_name, project_dir, context, delete_project_on_failure
+ )
-def generate_files(repo_dir, context=None, output_dir='.',
- overwrite_if_exists=False, skip_if_file_exists=False, accept_hooks=True,
- keep_project_on_failure=False):
+def generate_files(
+ repo_dir,
+ context=None,
+ output_dir='.',
+ overwrite_if_exists=False,
+ skip_if_file_exists=False,
+ accept_hooks=True,
+ keep_project_on_failure=False,
+):
"""Render the templates and saves them to files.
:param repo_dir: Project template input directory.
@@ -112,4 +308,120 @@ def generate_files(repo_dir, context=None, output_dir='.',
:param keep_project_on_failure: If `True` keep generated project directory even when
generation fails
"""
- pass
+ context = context or OrderedDict([])
+
+ env = create_env_with_context(context)
+
+ template_dir = find_template(repo_dir, env)
+ logger.debug('Generating project from %s...', template_dir)
+
+ unrendered_dir = os.path.split(template_dir)[1]
+ try:
+ project_dir, output_directory_created = render_and_create_dir(
+ unrendered_dir, context, output_dir, env, overwrite_if_exists
+ )
+ except UndefinedError as err:
+ msg = f"Unable to create project directory '{unrendered_dir}'"
+ raise UndefinedVariableInTemplate(msg, err, context) from err
+
+ # We want the Jinja path and the OS paths to match. Consequently, we'll:
+ # + CD to the template folder
+ # + Set Jinja's path to '.'
+ #
+ # In order to build our files to the correct folder(s), we'll use an
+ # absolute path for the target folder (project_dir)
+
+ project_dir = os.path.abspath(project_dir)
+ logger.debug('Project directory is %s', project_dir)
+
+ # if we created the output directory, then it's ok to remove it
+ # if rendering fails
+ delete_project_on_failure = output_directory_created and not keep_project_on_failure
+
+ if accept_hooks:
+ run_hook_from_repo_dir(
+ repo_dir, 'pre_gen_project', project_dir, context, delete_project_on_failure
+ )
+
+ with work_in(template_dir):
+ env.loader = FileSystemLoader(['.', '../templates'])
+
+ for root, dirs, files in os.walk('.'):
+ # We must separate the two types of dirs into different lists.
+ # The reason is that we don't want ``os.walk`` to go through the
+ # unrendered directories, since they will just be copied.
+ copy_dirs = []
+ render_dirs = []
+
+ for d in dirs:
+ d_ = os.path.normpath(os.path.join(root, d))
+ # We check the full path, because that's how it can be
+ # specified in the ``_copy_without_render`` setting, but
+ # we store just the dir name
+ if is_copy_only_path(d_, context):
+ logger.debug('Found copy only path %s', d)
+ copy_dirs.append(d)
+ else:
+ render_dirs.append(d)
+
+ for copy_dir in copy_dirs:
+ indir = os.path.normpath(os.path.join(root, copy_dir))
+ outdir = os.path.normpath(os.path.join(project_dir, indir))
+ outdir = env.from_string(outdir).render(**context)
+ logger.debug('Copying dir %s to %s without rendering', indir, outdir)
+
+ # The outdir is not the root dir, it is the dir which marked as copy
+ # only in the config file. If the program hits this line, which means
+ # the overwrite_if_exists = True, and root dir exists
+ if os.path.isdir(outdir):
+ shutil.rmtree(outdir)
+ shutil.copytree(indir, outdir)
+
+ # We mutate ``dirs``, because we only want to go through these dirs
+ # recursively
+ dirs[:] = render_dirs
+ for d in dirs:
+ unrendered_dir = os.path.join(project_dir, root, d)
+ try:
+ render_and_create_dir(
+ unrendered_dir, context, output_dir, env, overwrite_if_exists
+ )
+ except UndefinedError as err:
+ if delete_project_on_failure:
+ rmtree(project_dir)
+ _dir = os.path.relpath(unrendered_dir, output_dir)
+ msg = f"Unable to create directory '{_dir}'"
+ raise UndefinedVariableInTemplate(msg, err, context) from err
+
+ for f in files:
+ infile = os.path.normpath(os.path.join(root, f))
+ if is_copy_only_path(infile, context):
+ outfile_tmpl = env.from_string(infile)
+ outfile_rendered = outfile_tmpl.render(**context)
+ outfile = os.path.join(project_dir, outfile_rendered)
+ logger.debug(
+ 'Copying file %s to %s without rendering', infile, outfile
+ )
+ shutil.copyfile(infile, outfile)
+ shutil.copymode(infile, outfile)
+ continue
+ try:
+ generate_file(
+ project_dir, infile, context, env, skip_if_file_exists
+ )
+ except UndefinedError as err:
+ if delete_project_on_failure:
+ rmtree(project_dir)
+ msg = f"Unable to create file '{infile}'"
+ raise UndefinedVariableInTemplate(msg, err, context) from err
+
+ if accept_hooks:
+ run_hook_from_repo_dir(
+ repo_dir,
+ 'post_gen_project',
+ project_dir,
+ context,
+ delete_project_on_failure,
+ )
+
+ return project_dir
diff --git a/cookiecutter/hooks.py b/cookiecutter/hooks.py
index 0aa9c52..16b0647 100644
--- a/cookiecutter/hooks.py
+++ b/cookiecutter/hooks.py
@@ -1,17 +1,31 @@
"""Functions for discovering and executing various cookiecutter hooks."""
+
import errno
import logging
import os
-import subprocess
+import subprocess # nosec
import sys
import tempfile
from pathlib import Path
+
from jinja2.exceptions import UndefinedError
+
from cookiecutter import utils
from cookiecutter.exceptions import FailedHookException
-from cookiecutter.utils import create_env_with_context, create_tmp_repo_dir, rmtree, work_in
+from cookiecutter.utils import (
+ create_env_with_context,
+ create_tmp_repo_dir,
+ rmtree,
+ work_in,
+)
+
logger = logging.getLogger(__name__)
-_HOOKS = ['pre_prompt', 'pre_gen_project', 'post_gen_project']
+
+_HOOKS = [
+ 'pre_prompt',
+ 'pre_gen_project',
+ 'post_gen_project',
+]
EXIT_SUCCESS = 0
@@ -22,7 +36,13 @@ def valid_hook(hook_file, hook_name):
:param hook_name: The hook to find
:return: The hook file validity
"""
- pass
+ filename = os.path.basename(hook_file)
+ basename = os.path.splitext(filename)[0]
+ matching_hook = basename == hook_name
+ supported_hook = basename in _HOOKS
+ backup_file = filename.endswith('~')
+
+ return matching_hook and supported_hook and not backup_file
def find_hook(hook_name, hooks_dir='hooks'):
@@ -37,7 +57,20 @@ def find_hook(hook_name, hooks_dir='hooks'):
:param hooks_dir: The hook directory in the template
:return: The absolute path to the hook script or None
"""
- pass
+ logger.debug('hooks_dir is %s', os.path.abspath(hooks_dir))
+
+ if not os.path.isdir(hooks_dir):
+ logger.debug('No hooks/dir in template_dir')
+ return None
+
+ scripts = []
+ for hook_file in os.listdir(hooks_dir):
+ if valid_hook(hook_file, hook_name):
+ scripts.append(os.path.abspath(os.path.join(hooks_dir, hook_file)))
+
+ if len(scripts) == 0:
+ return None
+ return scripts
def run_script(script_path, cwd='.'):
@@ -46,7 +79,27 @@ def run_script(script_path, cwd='.'):
:param script_path: Absolute path to the script to run.
:param cwd: The directory to run the script from.
"""
- pass
+ run_thru_shell = sys.platform.startswith('win')
+ if script_path.endswith('.py'):
+ script_command = [sys.executable, script_path]
+ else:
+ script_command = [script_path]
+
+ utils.make_executable(script_path)
+
+ try:
+ proc = subprocess.Popen(script_command, shell=run_thru_shell, cwd=cwd) # nosec
+ exit_status = proc.wait()
+ if exit_status != EXIT_SUCCESS:
+ raise FailedHookException(
+ f'Hook script failed (exit status: {exit_status})'
+ )
+ except OSError as err:
+ if err.errno == errno.ENOEXEC:
+ raise FailedHookException(
+ 'Hook script failed, might be an empty file or missing a shebang'
+ ) from err
+ raise FailedHookException(f'Hook script failed (error: {err})') from err
def run_script_with_context(script_path, cwd, context):
@@ -56,7 +109,18 @@ def run_script_with_context(script_path, cwd, context):
:param cwd: The directory to run the script from.
:param context: Cookiecutter project template context.
"""
- pass
+ _, extension = os.path.splitext(script_path)
+
+ with open(script_path, encoding='utf-8') as file:
+ contents = file.read()
+
+ with tempfile.NamedTemporaryFile(delete=False, mode='wb', suffix=extension) as temp:
+ env = create_env_with_context(context)
+ template = env.from_string(contents)
+ output = template.render(**context)
+ temp.write(output.encode('utf-8'))
+
+ run_script(temp.name, cwd)
def run_hook(hook_name, project_dir, context):
@@ -67,11 +131,18 @@ def run_hook(hook_name, project_dir, context):
:param project_dir: The directory to execute the script from.
:param context: Cookiecutter project context.
"""
- pass
-
-
-def run_hook_from_repo_dir(repo_dir, hook_name, project_dir, context,
- delete_project_on_failure):
+ scripts = find_hook(hook_name)
+ if not scripts:
+ logger.debug('No %s hook found', hook_name)
+ return
+ logger.debug('Running hook %s', hook_name)
+ for script in scripts:
+ run_script_with_context(script, project_dir, context)
+
+
+def run_hook_from_repo_dir(
+ repo_dir, hook_name, project_dir, context, delete_project_on_failure
+):
"""Run hook from repo directory, clean project directory if hook fails.
:param repo_dir: Project template input directory.
@@ -81,12 +152,41 @@ def run_hook_from_repo_dir(repo_dir, hook_name, project_dir, context,
:param delete_project_on_failure: Delete the project directory on hook
failure?
"""
- pass
-
-
-def run_pre_prompt_hook(repo_dir: 'os.PathLike[str]') ->Path:
+ with work_in(repo_dir):
+ try:
+ run_hook(hook_name, project_dir, context)
+ except (
+ FailedHookException,
+ UndefinedError,
+ ):
+ if delete_project_on_failure:
+ rmtree(project_dir)
+ logger.error(
+ "Stopping generation because %s hook "
+ "script didn't exit successfully",
+ hook_name,
+ )
+ raise
+
+
+def run_pre_prompt_hook(repo_dir: "os.PathLike[str]") -> Path:
"""Run pre_prompt hook from repo directory.
:param repo_dir: Project template input directory.
"""
- pass
+ # Check if we have a valid pre_prompt script
+ with work_in(repo_dir):
+ scripts = find_hook('pre_prompt')
+ if not scripts:
+ return repo_dir
+
+ # Create a temporary directory
+ repo_dir = create_tmp_repo_dir(repo_dir)
+ with work_in(repo_dir):
+ scripts = find_hook('pre_prompt')
+ for script in scripts:
+ try:
+ run_script(script, repo_dir)
+ except FailedHookException:
+ raise FailedHookException('Pre-Prompt Hook script failed')
+ return repo_dir
diff --git a/cookiecutter/log.py b/cookiecutter/log.py
index 894c633..c2ac283 100644
--- a/cookiecutter/log.py
+++ b/cookiecutter/log.py
@@ -1,10 +1,20 @@
"""Module for setting up logging."""
+
import logging
import sys
-LOG_LEVELS = {'DEBUG': logging.DEBUG, 'INFO': logging.INFO, 'WARNING':
- logging.WARNING, 'ERROR': logging.ERROR, 'CRITICAL': logging.CRITICAL}
-LOG_FORMATS = {'DEBUG': '%(levelname)s %(name)s: %(message)s', 'INFO':
- '%(levelname)s: %(message)s'}
+
+LOG_LEVELS = {
+ 'DEBUG': logging.DEBUG,
+ 'INFO': logging.INFO,
+ 'WARNING': logging.WARNING,
+ 'ERROR': logging.ERROR,
+ 'CRITICAL': logging.CRITICAL,
+}
+
+LOG_FORMATS = {
+ 'DEBUG': '%(levelname)s %(name)s: %(message)s',
+ 'INFO': '%(levelname)s: %(message)s',
+}
def configure_logger(stream_level='DEBUG', debug_file=None):
@@ -13,4 +23,30 @@ def configure_logger(stream_level='DEBUG', debug_file=None):
Set up logging to stdout with given level. If ``debug_file`` is given set
up logging to file with DEBUG level.
"""
- pass
+ # Set up 'cookiecutter' logger
+ logger = logging.getLogger('cookiecutter')
+ logger.setLevel(logging.DEBUG)
+
+ # Remove all attached handlers, in case there was
+ # a logger with using the name 'cookiecutter'
+ del logger.handlers[:]
+
+ # Create a file handler if a log file is provided
+ if debug_file is not None:
+ debug_formatter = logging.Formatter(LOG_FORMATS['DEBUG'])
+ file_handler = logging.FileHandler(debug_file)
+ file_handler.setLevel(LOG_LEVELS['DEBUG'])
+ file_handler.setFormatter(debug_formatter)
+ logger.addHandler(file_handler)
+
+ # Get settings based on the given stream_level
+ log_formatter = logging.Formatter(LOG_FORMATS[stream_level])
+ log_level = LOG_LEVELS[stream_level]
+
+ # Create a stream handler
+ stream_handler = logging.StreamHandler(stream=sys.stdout)
+ stream_handler.setLevel(log_level)
+ stream_handler.setFormatter(log_formatter)
+ logger.addHandler(stream_handler)
+
+ return logger
diff --git a/cookiecutter/main.py b/cookiecutter/main.py
index 4b1087d..2146c1b 100644
--- a/cookiecutter/main.py
+++ b/cookiecutter/main.py
@@ -4,11 +4,13 @@ Main entry point for the `cookiecutter` command.
The code in this module is also a good example of how to use Cookiecutter as a
library rather than a script.
"""
+
import logging
import os
import sys
from copy import copy
from pathlib import Path
+
from cookiecutter.config import get_user_config
from cookiecutter.exceptions import InvalidModeException
from cookiecutter.generate import generate_context, generate_files
@@ -17,14 +19,26 @@ from cookiecutter.prompt import choose_nested_template, prompt_for_config
from cookiecutter.replay import dump, load
from cookiecutter.repository import determine_repo_dir
from cookiecutter.utils import rmtree
+
logger = logging.getLogger(__name__)
-def cookiecutter(template, checkout=None, no_input=False, extra_context=
- None, replay=None, overwrite_if_exists=False, output_dir='.',
- config_file=None, default_config=False, password=None, directory=None,
- skip_if_file_exists=False, accept_hooks=True, keep_project_on_failure=False
- ):
+def cookiecutter(
+ template,
+ checkout=None,
+ no_input=False,
+ extra_context=None,
+ replay=None,
+ overwrite_if_exists=False,
+ output_dir='.',
+ config_file=None,
+ default_config=False,
+ password=None,
+ directory=None,
+ skip_if_file_exists=False,
+ accept_hooks=True,
+ keep_project_on_failure=False,
+):
"""
Run Cookiecutter just as if using it from the command line.
@@ -52,14 +66,140 @@ def cookiecutter(template, checkout=None, no_input=False, extra_context=
:param keep_project_on_failure: If `True` keep generated project directory even when
generation fails
"""
- pass
+ if replay and ((no_input is not False) or (extra_context is not None)):
+ err_msg = (
+ "You can not use both replay and no_input or extra_context "
+ "at the same time."
+ )
+ raise InvalidModeException(err_msg)
+ config_dict = get_user_config(
+ config_file=config_file,
+ default_config=default_config,
+ )
+ base_repo_dir, cleanup_base_repo_dir = determine_repo_dir(
+ template=template,
+ abbreviations=config_dict['abbreviations'],
+ clone_to_dir=config_dict['cookiecutters_dir'],
+ checkout=checkout,
+ no_input=no_input,
+ password=password,
+ directory=directory,
+ )
+ repo_dir, cleanup = base_repo_dir, cleanup_base_repo_dir
+ # Run pre_prompt hook
+ repo_dir = run_pre_prompt_hook(base_repo_dir) if accept_hooks else repo_dir
+ # Always remove temporary dir if it was created
+ cleanup = True if repo_dir != base_repo_dir else False
-class _patch_import_path_for_repo:
+ import_patch = _patch_import_path_for_repo(repo_dir)
+ template_name = os.path.basename(os.path.abspath(repo_dir))
+ if replay:
+ with import_patch:
+ if isinstance(replay, bool):
+ context_from_replayfile = load(config_dict['replay_dir'], template_name)
+ else:
+ path, template_name = os.path.split(os.path.splitext(replay)[0])
+ context_from_replayfile = load(path, template_name)
+
+ context_file = os.path.join(repo_dir, 'cookiecutter.json')
+ logger.debug('context_file is %s', context_file)
+
+ if replay:
+ context = generate_context(
+ context_file=context_file,
+ default_context=config_dict['default_context'],
+ extra_context=None,
+ )
+ logger.debug('replayfile context: %s', context_from_replayfile)
+ items_for_prompting = {
+ k: v
+ for k, v in context['cookiecutter'].items()
+ if k not in context_from_replayfile['cookiecutter'].keys()
+ }
+ context_for_prompting = {}
+ context_for_prompting['cookiecutter'] = items_for_prompting
+ context = context_from_replayfile
+ logger.debug('prompting context: %s', context_for_prompting)
+ else:
+ context = generate_context(
+ context_file=context_file,
+ default_context=config_dict['default_context'],
+ extra_context=extra_context,
+ )
+ context_for_prompting = context
+ # preserve the original cookiecutter options
+ # print(context['cookiecutter'])
+ context['_cookiecutter'] = {
+ k: v for k, v in context['cookiecutter'].items() if not k.startswith("_")
+ }
+
+ # prompt the user to manually configure at the command line.
+ # except when 'no-input' flag is set
+
+ with import_patch:
+ if {"template", "templates"} & set(context["cookiecutter"].keys()):
+ nested_template = choose_nested_template(context, repo_dir, no_input)
+ return cookiecutter(
+ template=nested_template,
+ checkout=checkout,
+ no_input=no_input,
+ extra_context=extra_context,
+ replay=replay,
+ overwrite_if_exists=overwrite_if_exists,
+ output_dir=output_dir,
+ config_file=config_file,
+ default_config=default_config,
+ password=password,
+ directory=directory,
+ skip_if_file_exists=skip_if_file_exists,
+ accept_hooks=accept_hooks,
+ keep_project_on_failure=keep_project_on_failure,
+ )
+ if context_for_prompting['cookiecutter']:
+ context['cookiecutter'].update(
+ prompt_for_config(context_for_prompting, no_input)
+ )
+
+ logger.debug('context is %s', context)
- def __init__(self, repo_dir: 'os.PathLike[str]'):
- self._repo_dir = f'{repo_dir}' if isinstance(repo_dir, Path
- ) else repo_dir
+ # include template dir or url in the context dict
+ context['cookiecutter']['_template'] = template
+
+ # include output+dir in the context dict
+ context['cookiecutter']['_output_dir'] = os.path.abspath(output_dir)
+
+ # include repo dir or url in the context dict
+ context['cookiecutter']['_repo_dir'] = f"{repo_dir}"
+
+ # include checkout details in the context dict
+ context['cookiecutter']['_checkout'] = checkout
+
+ dump(config_dict['replay_dir'], template_name, context)
+
+ # Create project from local context and project template.
+ with import_patch:
+ result = generate_files(
+ repo_dir=repo_dir,
+ context=context,
+ overwrite_if_exists=overwrite_if_exists,
+ skip_if_file_exists=skip_if_file_exists,
+ output_dir=output_dir,
+ accept_hooks=accept_hooks,
+ keep_project_on_failure=keep_project_on_failure,
+ )
+
+ # Cleanup (if required)
+ if cleanup:
+ rmtree(repo_dir)
+ if cleanup_base_repo_dir:
+ rmtree(base_repo_dir)
+ return result
+
+
+class _patch_import_path_for_repo:
+ def __init__(self, repo_dir: "os.PathLike[str]"):
+ self._repo_dir = f"{repo_dir}" if isinstance(repo_dir, Path) else repo_dir
self._path = None
def __enter__(self):
diff --git a/cookiecutter/prompt.py b/cookiecutter/prompt.py
index 2bcc55f..761ac99 100644
--- a/cookiecutter/prompt.py
+++ b/cookiecutter/prompt.py
@@ -1,36 +1,57 @@
"""Functions for prompting the user for project info."""
+
import json
import os
import re
import sys
from collections import OrderedDict
from pathlib import Path
+
from jinja2.exceptions import UndefinedError
from rich.prompt import Confirm, InvalidResponse, Prompt, PromptBase
+
from cookiecutter.exceptions import UndefinedVariableInTemplate
from cookiecutter.utils import create_env_with_context, rmtree
-def read_user_variable(var_name, default_value, prompts=None, prefix=''):
+def read_user_variable(var_name, default_value, prompts=None, prefix=""):
"""Prompt user for variable and return the entered value or given default.
:param str var_name: Variable of the context to query the user
:param default_value: Value that will be returned if no input happens
"""
- pass
+ question = (
+ prompts[var_name]
+ if prompts and var_name in prompts.keys() and prompts[var_name]
+ else var_name
+ )
+
+ while True:
+ variable = Prompt.ask(f"{prefix}{question}", default=default_value)
+ if variable is not None:
+ break
+
+ return variable
class YesNoPrompt(Confirm):
"""A prompt that returns a boolean for yes/no questions."""
- yes_choices = ['1', 'true', 't', 'yes', 'y', 'on']
- no_choices = ['0', 'false', 'f', 'no', 'n', 'off']
- def process_response(self, value: str) ->bool:
+ yes_choices = ["1", "true", "t", "yes", "y", "on"]
+ no_choices = ["0", "false", "f", "no", "n", "off"]
+
+ def process_response(self, value: str) -> bool:
"""Convert choices to a bool."""
- pass
+ value = value.strip().lower()
+ if value in self.yes_choices:
+ return True
+ elif value in self.no_choices:
+ return False
+ else:
+ raise InvalidResponse(self.validate_error_message)
-def read_user_yes_no(var_name, default_value, prompts=None, prefix=''):
+def read_user_yes_no(var_name, default_value, prompts=None, prefix=""):
"""Prompt the user to reply with 'yes' or 'no' (or equivalent values).
- These input values will be converted to ``True``:
@@ -44,7 +65,12 @@ def read_user_yes_no(var_name, default_value, prompts=None, prefix=''):
:param str question: Question to the user
:param default_value: Value that will be returned if no input happens
"""
- pass
+ question = (
+ prompts[var_name]
+ if prompts and var_name in prompts.keys() and prompts[var_name]
+ else var_name
+ )
+ return YesNoPrompt.ask(f"{prefix}{question}", default=default_value)
def read_repo_password(question):
@@ -52,10 +78,10 @@ def read_repo_password(question):
:param str question: Question to the user
"""
- pass
+ return Prompt.ask(question, password=True)
-def read_user_choice(var_name, options, prompts=None, prefix=''):
+def read_user_choice(var_name, options, prompts=None, prefix=""):
"""Prompt the user to choose from several options for the given variable.
The first item will be returned if no input happens.
@@ -64,7 +90,46 @@ def read_user_choice(var_name, options, prompts=None, prefix=''):
:param list options: Sequence of options that are available to select from
:return: Exactly one item of ``options`` that has been chosen by the user
"""
- pass
+ if not isinstance(options, list):
+ raise TypeError
+
+ if not options:
+ raise ValueError
+
+ choice_map = OrderedDict((f'{i}', value) for i, value in enumerate(options, 1))
+ choices = choice_map.keys()
+
+ question = f"Select {var_name}"
+ choice_lines = [
+ ' [bold magenta]{}[/] - [bold]{}[/]'.format(*c) for c in choice_map.items()
+ ]
+
+ # Handle if human-readable prompt is provided
+ if prompts and var_name in prompts.keys():
+ if isinstance(prompts[var_name], str):
+ question = prompts[var_name]
+ else:
+ if "__prompt__" in prompts[var_name]:
+ question = prompts[var_name]["__prompt__"]
+ choice_lines = [
+ (
+ f" [bold magenta]{i}[/] - [bold]{prompts[var_name][p]}[/]"
+ if p in prompts[var_name]
+ else f" [bold magenta]{i}[/] - [bold]{p}[/]"
+ )
+ for i, p in choice_map.items()
+ ]
+
+ prompt = '\n'.join(
+ (
+ f"{prefix}{question}",
+ "\n".join(choice_lines),
+ " Choose from",
+ )
+ )
+
+ user_choice = Prompt.ask(prompt, choices=list(choices), default=list(choices)[0])
+ return choice_map[user_choice]
DEFAULT_DISPLAY = 'default'
@@ -75,29 +140,52 @@ def process_json(user_value, default_value=None):
:param str user_value: User-supplied value to load as a JSON dict
"""
- pass
+ try:
+ user_dict = json.loads(user_value, object_pairs_hook=OrderedDict)
+ except Exception as error:
+ # Leave it up to click to ask the user again
+ raise InvalidResponse('Unable to decode to JSON.') from error
+
+ if not isinstance(user_dict, dict):
+ # Leave it up to click to ask the user again
+ raise InvalidResponse('Requires JSON dict.')
+
+ return user_dict
class JsonPrompt(PromptBase[dict]):
"""A prompt that returns a dict from JSON string."""
+
default = None
response_type = dict
- validate_error_message = (
- '[prompt.invalid] Please enter a valid JSON string')
+ validate_error_message = "[prompt.invalid] Please enter a valid JSON string"
- def process_response(self, value: str) ->dict:
+ def process_response(self, value: str) -> dict:
"""Convert choices to a dict."""
- pass
+ return process_json(value, self.default)
-def read_user_dict(var_name, default_value, prompts=None, prefix=''):
+def read_user_dict(var_name, default_value, prompts=None, prefix=""):
"""Prompt the user to provide a dictionary of data.
:param str var_name: Variable as specified in the context
:param default_value: Value that will be returned if no input is provided
:return: A Python dictionary to use in the context.
"""
- pass
+ if not isinstance(default_value, dict):
+ raise TypeError
+
+ question = (
+ prompts[var_name]
+ if prompts and var_name in prompts.keys() and prompts[var_name]
+ else var_name
+ )
+ user_value = JsonPrompt.ask(
+ f"{prefix}{question} [cyan bold]({DEFAULT_DISPLAY})[/]",
+ default=default_value,
+ show_default=False,
+ )
+ return user_value
def render_variable(env, raw, cookiecutter_dict):
@@ -117,12 +205,34 @@ def render_variable(env, raw, cookiecutter_dict):
being populated with variables.
:return: The rendered value for the default variable.
"""
- pass
-
-
-def _prompts_from_options(options: dict) ->dict:
+ if raw is None or isinstance(raw, bool):
+ return raw
+ elif isinstance(raw, dict):
+ return {
+ render_variable(env, k, cookiecutter_dict): render_variable(
+ env, v, cookiecutter_dict
+ )
+ for k, v in raw.items()
+ }
+ elif isinstance(raw, list):
+ return [render_variable(env, v, cookiecutter_dict) for v in raw]
+ elif not isinstance(raw, str):
+ raw = str(raw)
+
+ template = env.from_string(raw)
+
+ return template.render(cookiecutter=cookiecutter_dict)
+
+
+def _prompts_from_options(options: dict) -> dict:
"""Process template options and return friendly prompt information."""
- pass
+ prompts = {"__prompt__": "Select a template"}
+ for option_key, option_value in options.items():
+ title = str(option_value.get("title", option_key))
+ description = option_value.get("description", option_key)
+ label = title if title == description else f"{title} ({description})"
+ prompts[option_key] = label
+ return prompts
def prompt_choice_for_template(key, options, no_input):
@@ -130,16 +240,22 @@ def prompt_choice_for_template(key, options, no_input):
:param no_input: Do not prompt for user input and return the first available option.
"""
- pass
+ opts = list(options.keys())
+ prompts = {"templates": _prompts_from_options(options)}
+ return opts[0] if no_input else read_user_choice(key, opts, prompts, "")
-def prompt_choice_for_config(cookiecutter_dict, env, key, options, no_input,
- prompts=None, prefix=''):
+def prompt_choice_for_config(
+ cookiecutter_dict, env, key, options, no_input, prompts=None, prefix=""
+):
"""Prompt user with a set of options to choose from.
:param no_input: Do not prompt for user input and return the first available option.
"""
- pass
+ rendered_options = [render_variable(env, raw, cookiecutter_dict) for raw in options]
+ if no_input:
+ return rendered_options[0]
+ return read_user_choice(key, rendered_options, prompts, prefix)
def prompt_for_config(context, no_input=False):
@@ -148,11 +264,81 @@ def prompt_for_config(context, no_input=False):
:param dict context: Source for field names and sample values.
:param no_input: Do not prompt for user input and use only values from context.
"""
- pass
-
-
-def choose_nested_template(context: dict, repo_dir: str, no_input: bool=False
- ) ->str:
+ cookiecutter_dict = OrderedDict([])
+ env = create_env_with_context(context)
+ prompts = context['cookiecutter'].pop('__prompts__', {})
+
+ # First pass: Handle simple and raw variables, plus choices.
+ # These must be done first because the dictionaries keys and
+ # values might refer to them.
+ count = 0
+ all_prompts = context['cookiecutter'].items()
+ visible_prompts = [k for k, _ in all_prompts if not k.startswith("_")]
+ size = len(visible_prompts)
+ for key, raw in all_prompts:
+ if key.startswith('_') and not key.startswith('__'):
+ cookiecutter_dict[key] = raw
+ continue
+ elif key.startswith('__'):
+ cookiecutter_dict[key] = render_variable(env, raw, cookiecutter_dict)
+ continue
+
+ if not isinstance(raw, dict):
+ count += 1
+ prefix = f" [dim][{count}/{size}][/] "
+
+ try:
+ if isinstance(raw, list):
+ # We are dealing with a choice variable
+ val = prompt_choice_for_config(
+ cookiecutter_dict, env, key, raw, no_input, prompts, prefix
+ )
+ cookiecutter_dict[key] = val
+ elif isinstance(raw, bool):
+ # We are dealing with a boolean variable
+ if no_input:
+ cookiecutter_dict[key] = render_variable(
+ env, raw, cookiecutter_dict
+ )
+ else:
+ cookiecutter_dict[key] = read_user_yes_no(key, raw, prompts, prefix)
+ elif not isinstance(raw, dict):
+ # We are dealing with a regular variable
+ val = render_variable(env, raw, cookiecutter_dict)
+
+ if not no_input:
+ val = read_user_variable(key, val, prompts, prefix)
+
+ cookiecutter_dict[key] = val
+ except UndefinedError as err:
+ msg = f"Unable to render variable '{key}'"
+ raise UndefinedVariableInTemplate(msg, err, context) from err
+
+ # Second pass; handle the dictionaries.
+ for key, raw in context['cookiecutter'].items():
+ # Skip private type dicts not to be rendered.
+ if key.startswith('_') and not key.startswith('__'):
+ continue
+
+ try:
+ if isinstance(raw, dict):
+ # We are dealing with a dict variable
+ count += 1
+ prefix = f" [dim][{count}/{size}][/] "
+ val = render_variable(env, raw, cookiecutter_dict)
+
+ if not no_input and not key.startswith('__'):
+ val = read_user_dict(key, val, prompts, prefix)
+
+ cookiecutter_dict[key] = val
+ except UndefinedError as err:
+ msg = f"Unable to render variable '{key}'"
+ raise UndefinedVariableInTemplate(msg, err, context) from err
+
+ return cookiecutter_dict
+
+
+def choose_nested_template(context: dict, repo_dir: str, no_input: bool = False) -> str:
"""Prompt user to select the nested template to use.
:param context: Source for field names and sample values.
@@ -160,7 +346,33 @@ def choose_nested_template(context: dict, repo_dir: str, no_input: bool=False
:param no_input: Do not prompt for user input and use only values from context.
:returns: Path to the selected template.
"""
- pass
+ cookiecutter_dict = OrderedDict([])
+ env = create_env_with_context(context)
+ prefix = ""
+ prompts = context['cookiecutter'].pop('__prompts__', {})
+ key = "templates"
+ config = context['cookiecutter'].get(key, {})
+ if config:
+ # Pass
+ val = prompt_choice_for_template(key, config, no_input)
+ template = config[val]["path"]
+ else:
+ # Old style
+ key = "template"
+ config = context['cookiecutter'].get(key, [])
+ val = prompt_choice_for_config(
+ cookiecutter_dict, env, key, config, no_input, prompts, prefix
+ )
+ template = re.search(r'\((.+)\)', val).group(1)
+
+ template = Path(template) if template else None
+ if not (template and not template.is_absolute()):
+ raise ValueError("Illegal template path")
+
+ repo_dir = Path(repo_dir).resolve()
+ template_path = (repo_dir / template).resolve()
+ # Return path as string
+ return f"{template_path}"
def prompt_and_delete(path, no_input=False):
@@ -174,4 +386,28 @@ def prompt_and_delete(path, no_input=False):
:param no_input: Suppress prompt to delete repo and just delete it.
:return: True if the content was deleted
"""
- pass
+ # Suppress prompt if called via API
+ if no_input:
+ ok_to_delete = True
+ else:
+ question = (
+ f"You've downloaded {path} before. Is it okay to delete and re-download it?"
+ )
+
+ ok_to_delete = read_user_yes_no(question, 'yes')
+
+ if ok_to_delete:
+ if os.path.isdir(path):
+ rmtree(path)
+ else:
+ os.remove(path)
+ return True
+ else:
+ ok_to_reuse = read_user_yes_no(
+ "Do you want to re-use the existing version?", 'yes'
+ )
+
+ if ok_to_reuse:
+ return False
+
+ sys.exit()
diff --git a/cookiecutter/replay.py b/cookiecutter/replay.py
index 340be41..196f2b1 100644
--- a/cookiecutter/replay.py
+++ b/cookiecutter/replay.py
@@ -3,21 +3,50 @@ cookiecutter.replay.
-------------------
"""
+
import json
import os
+
from cookiecutter.utils import make_sure_path_exists
def get_file_name(replay_dir, template_name):
"""Get the name of file."""
- pass
+ suffix = '.json' if not template_name.endswith('.json') else ''
+ file_name = f'{template_name}{suffix}'
+ return os.path.join(replay_dir, file_name)
-def dump(replay_dir: 'os.PathLike[str]', template_name: str, context: dict):
+def dump(replay_dir: "os.PathLike[str]", template_name: str, context: dict):
"""Write json data to file."""
- pass
+ make_sure_path_exists(replay_dir)
+
+ if not isinstance(template_name, str):
+ raise TypeError('Template name is required to be of type str')
+
+ if not isinstance(context, dict):
+ raise TypeError('Context is required to be of type dict')
+
+ if 'cookiecutter' not in context:
+ raise ValueError('Context is required to contain a cookiecutter key')
+
+ replay_file = get_file_name(replay_dir, template_name)
+
+ with open(replay_file, 'w', encoding="utf-8") as outfile:
+ json.dump(context, outfile, indent=2)
def load(replay_dir, template_name):
"""Read json data from file."""
- pass
+ if not isinstance(template_name, str):
+ raise TypeError('Template name is required to be of type str')
+
+ replay_file = get_file_name(replay_dir, template_name)
+
+ with open(replay_file, encoding="utf-8") as infile:
+ context = json.load(infile)
+
+ if 'cookiecutter' not in context:
+ raise ValueError('Context is required to contain a cookiecutter key')
+
+ return context
diff --git a/cookiecutter/repository.py b/cookiecutter/repository.py
index e350c56..cc5576a 100644
--- a/cookiecutter/repository.py
+++ b/cookiecutter/repository.py
@@ -1,28 +1,32 @@
"""Cookiecutter repository functions."""
+
import os
import re
+
from cookiecutter.exceptions import RepositoryNotFound
from cookiecutter.vcs import clone
from cookiecutter.zipfile import unzip
+
REPO_REGEX = re.compile(
- """
+ r"""
# something like git:// ssh:// file:// etc.
-((((git|hg)\\+)?(git|ssh|file|https?):(//)?)
+((((git|hg)\+)?(git|ssh|file|https?):(//)?)
| # or
- (\\w+@[\\w\\.]+) # something like user@...
+ (\w+@[\w\.]+) # something like user@...
+)
+""",
+ re.VERBOSE,
)
-"""
- , re.VERBOSE)
def is_repo_url(value):
"""Return True if value is a repository URL."""
- pass
+ return bool(REPO_REGEX.match(value))
def is_zip_file(value):
"""Return True if value is a zip file."""
- pass
+ return value.lower().endswith('.zip')
def expand_abbreviations(template, abbreviations):
@@ -31,7 +35,16 @@ def expand_abbreviations(template, abbreviations):
:param template: The project template name.
:param abbreviations: Abbreviation definitions.
"""
- pass
+ if template in abbreviations:
+ return abbreviations[template]
+
+ # Split on colon. If there is no colon, rest will be empty
+ # and prefix will be the whole template
+ prefix, sep, rest = template.partition(':')
+ if prefix in abbreviations:
+ return abbreviations[prefix].format(rest)
+
+ return template
def repository_has_cookiecutter_json(repo_directory):
@@ -40,11 +53,23 @@ def repository_has_cookiecutter_json(repo_directory):
:param repo_directory: The candidate repository directory.
:return: True if the `repo_directory` is valid, else False.
"""
- pass
+ repo_directory_exists = os.path.isdir(repo_directory)
+
+ repo_config_exists = os.path.isfile(
+ os.path.join(repo_directory, 'cookiecutter.json')
+ )
+ return repo_directory_exists and repo_config_exists
-def determine_repo_dir(template, abbreviations, clone_to_dir, checkout,
- no_input, password=None, directory=None):
+def determine_repo_dir(
+ template,
+ abbreviations,
+ clone_to_dir,
+ checkout,
+ no_input,
+ password=None,
+ directory=None,
+):
"""
Locate the repository directory from a template reference.
@@ -67,4 +92,41 @@ def determine_repo_dir(template, abbreviations, clone_to_dir, checkout,
after the template has been instantiated.
:raises: `RepositoryNotFound` if a repository directory could not be found.
"""
- pass
+ template = expand_abbreviations(template, abbreviations)
+
+ if is_zip_file(template):
+ unzipped_dir = unzip(
+ zip_uri=template,
+ is_url=is_repo_url(template),
+ clone_to_dir=clone_to_dir,
+ no_input=no_input,
+ password=password,
+ )
+ repository_candidates = [unzipped_dir]
+ cleanup = True
+ elif is_repo_url(template):
+ cloned_repo = clone(
+ repo_url=template,
+ checkout=checkout,
+ clone_to_dir=clone_to_dir,
+ no_input=no_input,
+ )
+ repository_candidates = [cloned_repo]
+ cleanup = False
+ else:
+ repository_candidates = [template, os.path.join(clone_to_dir, template)]
+ cleanup = False
+
+ if directory:
+ repository_candidates = [
+ os.path.join(s, directory) for s in repository_candidates
+ ]
+
+ for repo_candidate in repository_candidates:
+ if repository_has_cookiecutter_json(repo_candidate):
+ return repo_candidate, cleanup
+
+ raise RepositoryNotFound(
+ 'A valid repository for "{}" could not be found in the following '
+ 'locations:\n{}'.format(template, '\n'.join(repository_candidates))
+ )
diff --git a/cookiecutter/utils.py b/cookiecutter/utils.py
index 6aa68ba..b21252c 100644
--- a/cookiecutter/utils.py
+++ b/cookiecutter/utils.py
@@ -1,4 +1,5 @@
"""Helper functions used throughout Cookiecutter."""
+
import contextlib
import logging
import os
@@ -7,8 +8,11 @@ import stat
import tempfile
from pathlib import Path
from typing import Dict
+
from jinja2.ext import Extension
+
from cookiecutter.environment import StrictEnvironment
+
logger = logging.getLogger(__name__)
@@ -18,7 +22,8 @@ def force_delete(func, path, exc_info):
Usage: `shutil.rmtree(path, onerror=force_delete)`
From https://docs.python.org/3/library/shutil.html#rmtree-example
"""
- pass
+ os.chmod(path, stat.S_IWRITE)
+ func(path)
def rmtree(path):
@@ -26,15 +31,19 @@ def rmtree(path):
:param path: A directory path.
"""
- pass
+ shutil.rmtree(path, onerror=force_delete)
-def make_sure_path_exists(path: 'os.PathLike[str]') ->None:
+def make_sure_path_exists(path: "os.PathLike[str]") -> None:
"""Ensure that a directory exists.
:param path: A directory tree path for creation.
"""
- pass
+ logger.debug('Making sure path exists (creates tree if not exist): %s', path)
+ try:
+ Path(path).mkdir(parents=True, exist_ok=True)
+ except OSError as error:
+ raise OSError(f'Unable to create directory at {path}') from error
@contextlib.contextmanager
@@ -43,7 +52,13 @@ def work_in(dirname=None):
When exited, returns to the working directory prior to entering.
"""
- pass
+ curdir = os.getcwd()
+ try:
+ if dirname is not None:
+ os.chdir(dirname)
+ yield
+ finally:
+ os.chdir(curdir)
def make_executable(script_path):
@@ -51,19 +66,34 @@ def make_executable(script_path):
:param script_path: The file to change
"""
- pass
+ status = os.stat(script_path)
+ os.chmod(script_path, status.st_mode | stat.S_IEXEC)
def simple_filter(filter_function):
"""Decorate a function to wrap it in a simplified jinja2 extension."""
- pass
+ class SimpleFilterExtension(Extension):
+ def __init__(self, environment):
+ super().__init__(environment)
+ environment.filters[filter_function.__name__] = filter_function
-def create_tmp_repo_dir(repo_dir: 'os.PathLike[str]') ->Path:
+ SimpleFilterExtension.__name__ = filter_function.__name__
+ return SimpleFilterExtension
+
+
+def create_tmp_repo_dir(repo_dir: "os.PathLike[str]") -> Path:
"""Create a temporary dir with a copy of the contents of repo_dir."""
- pass
+ repo_dir = Path(repo_dir).resolve()
+ base_dir = tempfile.mkdtemp(prefix='cookiecutter')
+ new_dir = f"{base_dir}/{repo_dir.name}"
+ logger.debug(f'Copying repo_dir from {repo_dir} to {new_dir}')
+ shutil.copytree(repo_dir, new_dir)
+ return Path(new_dir)
def create_env_with_context(context: Dict):
"""Create a jinja environment using the provided context."""
- pass
+ envvars = context.get('cookiecutter', {}).get('_jinja2_env_vars', {})
+
+ return StrictEnvironment(context=context, keep_trailing_newline=True, **envvars)
diff --git a/cookiecutter/vcs.py b/cookiecutter/vcs.py
index 94d6c05..db57ae9 100644
--- a/cookiecutter/vcs.py
+++ b/cookiecutter/vcs.py
@@ -1,15 +1,28 @@
"""Helper functions for working with version control systems."""
+
import logging
import os
-import subprocess
+import subprocess # nosec
from pathlib import Path
from shutil import which
from typing import Optional
-from cookiecutter.exceptions import RepositoryCloneFailed, RepositoryNotFound, UnknownRepoType, VCSNotInstalled
+
+from cookiecutter.exceptions import (
+ RepositoryCloneFailed,
+ RepositoryNotFound,
+ UnknownRepoType,
+ VCSNotInstalled,
+)
from cookiecutter.prompt import prompt_and_delete
from cookiecutter.utils import make_sure_path_exists
+
logger = logging.getLogger(__name__)
-BRANCH_ERRORS = ['error: pathspec', 'unknown revision']
+
+
+BRANCH_ERRORS = [
+ 'error: pathspec',
+ 'unknown revision',
+]
def identify_repo(repo_url):
@@ -20,7 +33,20 @@ def identify_repo(repo_url):
:param repo_url: Repo URL of unknown type.
:returns: ('git', repo_url), ('hg', repo_url), or None.
"""
- pass
+ repo_url_values = repo_url.split('+')
+ if len(repo_url_values) == 2:
+ repo_type = repo_url_values[0]
+ if repo_type in ["git", "hg"]:
+ return repo_type, repo_url_values[1]
+ else:
+ raise UnknownRepoType
+ else:
+ if 'git' in repo_url:
+ return 'git', repo_url
+ elif 'bitbucket' in repo_url:
+ return 'hg', repo_url
+ else:
+ raise UnknownRepoType
def is_vcs_installed(repo_type):
@@ -29,11 +55,15 @@ def is_vcs_installed(repo_type):
:param repo_type:
"""
- pass
+ return bool(which(repo_type))
-def clone(repo_url: str, checkout: Optional[str]=None, clone_to_dir:
- 'os.PathLike[str]'='.', no_input: bool=False):
+def clone(
+ repo_url: str,
+ checkout: Optional[str] = None,
+ clone_to_dir: "os.PathLike[str]" = ".",
+ no_input: bool = False,
+):
"""Clone a repo to the current directory.
:param repo_url: Repo URL of unknown type.
@@ -44,4 +74,62 @@ def clone(repo_url: str, checkout: Optional[str]=None, clone_to_dir:
cached resources.
:returns: str with path to the new directory of the repository.
"""
- pass
+ # Ensure that clone_to_dir exists
+ clone_to_dir = Path(clone_to_dir).expanduser()
+ make_sure_path_exists(clone_to_dir)
+
+ # identify the repo_type
+ repo_type, repo_url = identify_repo(repo_url)
+
+ # check that the appropriate VCS for the repo_type is installed
+ if not is_vcs_installed(repo_type):
+ msg = f"'{repo_type}' is not installed."
+ raise VCSNotInstalled(msg)
+
+ repo_url = repo_url.rstrip('/')
+ repo_name = os.path.split(repo_url)[1]
+ if repo_type == 'git':
+ repo_name = repo_name.split(':')[-1].rsplit('.git')[0]
+ repo_dir = os.path.normpath(os.path.join(clone_to_dir, repo_name))
+ if repo_type == 'hg':
+ repo_dir = os.path.normpath(os.path.join(clone_to_dir, repo_name))
+ logger.debug(f'repo_dir is {repo_dir}')
+
+ if os.path.isdir(repo_dir):
+ clone = prompt_and_delete(repo_dir, no_input=no_input)
+ else:
+ clone = True
+
+ if clone:
+ try:
+ subprocess.check_output( # nosec
+ [repo_type, 'clone', repo_url],
+ cwd=clone_to_dir,
+ stderr=subprocess.STDOUT,
+ )
+ if checkout is not None:
+ checkout_params = [checkout]
+ # Avoid Mercurial "--config" and "--debugger" injection vulnerability
+ if repo_type == "hg":
+ checkout_params.insert(0, "--")
+ subprocess.check_output( # nosec
+ [repo_type, 'checkout', *checkout_params],
+ cwd=repo_dir,
+ stderr=subprocess.STDOUT,
+ )
+ except subprocess.CalledProcessError as clone_error:
+ output = clone_error.output.decode('utf-8')
+ if 'not found' in output.lower():
+ raise RepositoryNotFound(
+ f'The repository {repo_url} could not be found, '
+ 'have you made a typo?'
+ ) from clone_error
+ if any(error in output for error in BRANCH_ERRORS):
+ raise RepositoryCloneFailed(
+ f'The {checkout} branch of repository '
+ f'{repo_url} could not found, have you made a typo?'
+ ) from clone_error
+ logger.error('git clone failed with error: %s', output)
+ raise
+
+ return repo_dir
diff --git a/cookiecutter/zipfile.py b/cookiecutter/zipfile.py
index c4d398a..e3cfb3e 100644
--- a/cookiecutter/zipfile.py
+++ b/cookiecutter/zipfile.py
@@ -1,17 +1,25 @@
"""Utility functions for handling and fetching repo archives in zip format."""
+
import os
import tempfile
from pathlib import Path
from typing import Optional
from zipfile import BadZipFile, ZipFile
+
import requests
+
from cookiecutter.exceptions import InvalidZipRepository
from cookiecutter.prompt import prompt_and_delete, read_repo_password
from cookiecutter.utils import make_sure_path_exists
-def unzip(zip_uri: str, is_url: bool, clone_to_dir: 'os.PathLike[str]'='.',
- no_input: bool=False, password: Optional[str]=None):
+def unzip(
+ zip_uri: str,
+ is_url: bool,
+ clone_to_dir: "os.PathLike[str]" = ".",
+ no_input: bool = False,
+ password: Optional[str] = None,
+):
"""Download and unpack a zipfile at a given URI.
This will download the zipfile to the cookiecutter repository,
@@ -25,4 +33,89 @@ def unzip(zip_uri: str, is_url: bool, clone_to_dir: 'os.PathLike[str]'='.',
cached resources.
:param password: The password to use when unpacking the repository.
"""
- pass
+ # Ensure that clone_to_dir exists
+ clone_to_dir = Path(clone_to_dir).expanduser()
+ make_sure_path_exists(clone_to_dir)
+
+ if is_url:
+ # Build the name of the cached zipfile,
+ # and prompt to delete if it already exists.
+ identifier = zip_uri.rsplit('/', 1)[1]
+ zip_path = os.path.join(clone_to_dir, identifier)
+
+ if os.path.exists(zip_path):
+ download = prompt_and_delete(zip_path, no_input=no_input)
+ else:
+ download = True
+
+ if download:
+ # (Re) download the zipfile
+ r = requests.get(zip_uri, stream=True, timeout=100)
+ with open(zip_path, 'wb') as f:
+ for chunk in r.iter_content(chunk_size=1024):
+ if chunk: # filter out keep-alive new chunks
+ f.write(chunk)
+ else:
+ # Just use the local zipfile as-is.
+ zip_path = os.path.abspath(zip_uri)
+
+ # Now unpack the repository. The zipfile will be unpacked
+ # into a temporary directory
+ try:
+ zip_file = ZipFile(zip_path)
+
+ if len(zip_file.namelist()) == 0:
+ raise InvalidZipRepository(f'Zip repository {zip_uri} is empty')
+
+ # The first record in the zipfile should be the directory entry for
+ # the archive. If it isn't a directory, there's a problem.
+ first_filename = zip_file.namelist()[0]
+ if not first_filename.endswith('/'):
+ raise InvalidZipRepository(
+ f"Zip repository {zip_uri} does not include a top-level directory"
+ )
+
+ # Construct the final target directory
+ project_name = first_filename[:-1]
+ unzip_base = tempfile.mkdtemp()
+ unzip_path = os.path.join(unzip_base, project_name)
+
+ # Extract the zip file into the temporary directory
+ try:
+ zip_file.extractall(path=unzip_base)
+ except RuntimeError:
+ # File is password protected; try to get a password from the
+ # environment; if that doesn't work, ask the user.
+ if password is not None:
+ try:
+ zip_file.extractall(path=unzip_base, pwd=password.encode('utf-8'))
+ except RuntimeError:
+ raise InvalidZipRepository(
+ 'Invalid password provided for protected repository'
+ )
+ elif no_input:
+ raise InvalidZipRepository(
+ 'Unable to unlock password protected repository'
+ )
+ else:
+ retry = 0
+ while retry is not None:
+ try:
+ password = read_repo_password('Repo password')
+ zip_file.extractall(
+ path=unzip_base, pwd=password.encode('utf-8')
+ )
+ retry = None
+ except RuntimeError:
+ retry += 1
+ if retry == 3:
+ raise InvalidZipRepository(
+ 'Invalid password provided for protected repository'
+ )
+
+ except BadZipFile:
+ raise InvalidZipRepository(
+ f'Zip repository {zip_uri} is not a valid zip archive:'
+ )
+
+ return unzip_path