Compare commits
4 Commits
test-cover
...
nginx-refa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fcc76618fa | ||
|
|
7a79c55af8 | ||
|
|
e52cd73b84 | ||
|
|
6c6dd3dd1a |
172
certbot-nginx/certbot_nginx/nginx_parser_obj.py
Normal file
172
certbot-nginx/certbot_nginx/nginx_parser_obj.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""NginxParser is a member object of the NginxConfigurator class."""
|
||||
import glob
|
||||
import logging
|
||||
import pyparsing
|
||||
|
||||
from certbot import errors
|
||||
from certbot.compat import os
|
||||
|
||||
from certbot_nginx import nginxparser
|
||||
from certbot_nginx import parser_obj as obj
|
||||
from certbot_nginx import obj as nginx_obj
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class NginxParseContext(obj.ParseContext):
|
||||
""" A parsing context which includes a set of parsing hooks specific to Nginx
|
||||
configuration files. """
|
||||
def __init__(self, parent=None, filename=None, cwd=None, parsed_files=None):
|
||||
super(NginxParseContext, self).__init__(parent, filename, cwd)
|
||||
self.parsed_files = parsed_files if parsed_files else {}
|
||||
|
||||
@staticmethod
|
||||
def parsing_hooks():
|
||||
return NGINX_PARSING_HOOKS
|
||||
|
||||
def child(self, parent, filename=None):
|
||||
return NginxParseContext(parent,filename if filename else self.filename,
|
||||
self.cwd, self.parsed_files)
|
||||
|
||||
def parse_from_file_nginx(context):
|
||||
""" Parses from a file specified by `context`.
|
||||
:param NginxParseContext context:
|
||||
:returns WithLists:
|
||||
"""
|
||||
raw_parsed = []
|
||||
with open(os.path.join(context.cwd, context.filename)) as _file:
|
||||
try:
|
||||
raw_parsed = nginxparser.load(_file, True)
|
||||
except pyparsing.ParseException as err:
|
||||
logger.debug("Could not parse file: %s due to %s", context.filename, err)
|
||||
context.parsed_files[context.filename] = None
|
||||
parsed = obj.parse_raw(raw_parsed, context=context)
|
||||
parsed.context.parsed_files[context.filename] = parsed
|
||||
return parsed
|
||||
|
||||
class Include(obj.Sentence):
|
||||
""" Represents an include statement. On parsing, tries to read and parse included file(s), while
|
||||
avoiding duplicates from `context.parsed`."""
|
||||
def __init__(self, context):
|
||||
super(Include, self).__init__(context)
|
||||
self.parsed = dict()
|
||||
|
||||
@staticmethod
|
||||
def should_parse(lists):
|
||||
return obj.Sentence.should_parse(lists) and "include" in lists
|
||||
|
||||
def parse(self, raw_list, add_spaces=False):
|
||||
""" Parsing an include this will try to fetch the associated files (if they exist)
|
||||
and parses them all. Any parsed files are added to the global context.parsed_files object.
|
||||
"""
|
||||
super(Include, self).parse(raw_list, add_spaces)
|
||||
filepath = self.filename
|
||||
if not os.path.isabs(filepath):
|
||||
filepath = os.path.join(self.context.cwd, self.filename)
|
||||
for f in glob.glob(filepath):
|
||||
self.parsed[f] = self.context.parsed_files[f] if f in self.context.parsed_files else \
|
||||
parse_from_file_nginx(self.child_context(f))
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
""" Retrieves the filename that is being included. """
|
||||
return self.words[1]
|
||||
|
||||
def iterate(self, expanded=False, match=None):
|
||||
""" Iterates itself, and if expanded is set, iterates over the `Directives` objects
|
||||
in all of the included files.
|
||||
"""
|
||||
if match is None or match(self):
|
||||
yield self
|
||||
if expanded:
|
||||
for parsed in self.parsed.values():
|
||||
for sub_elem in parsed.iterate(expanded, match):
|
||||
yield sub_elem
|
||||
|
||||
class ServerBlock(obj.Block):
|
||||
""" Parsing object which represents an Nginx server block.
|
||||
This bloc should parallel a "VirtualHost" object-- any update or modification should
|
||||
also update the corresponding virtual host object. """
|
||||
|
||||
REPEATABLE_DIRECTIVES = set(['server_name', 'listen', 'include', 'rewrite', 'add_header'])
|
||||
|
||||
def __init__(self, context=None):
|
||||
super(ServerBlock, self).__init__(context)
|
||||
self.vhost = None
|
||||
self.addrs = set()
|
||||
self.ssl = False
|
||||
self.server_names = set()
|
||||
|
||||
@staticmethod
|
||||
def should_parse(lists):
|
||||
return obj.Block.should_parse(lists) and "server" in lists[0]
|
||||
|
||||
def _update_vhost(self):
|
||||
# copied from _parse_server_raw
|
||||
self.addrs = set()
|
||||
self.ssl = False
|
||||
self.server_names = set()
|
||||
apply_ssl_to_all_addrs = False
|
||||
for directive in self.contents.get_type(obj.Sentence):
|
||||
if len(directive.words) == 0:
|
||||
continue
|
||||
if directive[0] == 'listen':
|
||||
addr = nginx_obj.Addr.fromstring(" ".join(directive[1:]))
|
||||
if addr:
|
||||
self.addrs.add(addr)
|
||||
if addr.ssl:
|
||||
self.ssl = True
|
||||
if directive[0] == 'server_name':
|
||||
self.server_names.update(x.strip('"\'') for x in directive[1:])
|
||||
for ssl in self.get_directives('ssl'):
|
||||
if ssl[1] == "on":
|
||||
self.ssl = True
|
||||
apply_ssl_to_all_addrs = True
|
||||
if apply_ssl_to_all_addrs:
|
||||
for addr in self.addrs:
|
||||
addr.ssl = True
|
||||
self.vhost.addrs = self.addrs
|
||||
self.vhost.names = self.server_names
|
||||
self.vhost.ssl = self.ssl
|
||||
self.vhost.raw = self.dump_unspaced_list()[1]
|
||||
self.vhost.raw_obj = self
|
||||
|
||||
def add_directive(self, raw_list, insert_at_top=False):
|
||||
""" Adds a single directive to this Server Block's contents, while enforcing
|
||||
repeatability rules."""
|
||||
statement = obj.parse_raw(raw_list, self.contents.child_context(), add_spaces=False)
|
||||
if isinstance(statement, obj.Sentence) and statement[0] not in self.REPEATABLE_DIRECTIVES \
|
||||
and len(list(self.get_directives(statement[0]))) > 0:
|
||||
raise errors.MisconfigurationError(
|
||||
"Existing %s directive conflicts with %s" % (statement[0], statement))
|
||||
self.contents.add_directive(statement, insert_at_top)
|
||||
|
||||
def update_or_add_directive(self, raw_list, insert_at_top=False):
|
||||
""" Adds a single directive to this Server Block's contents, while enforcing
|
||||
repeatability rules."""
|
||||
statement = obj.parse_raw(raw_list, self.contents.child_context(), add_spaces=False)
|
||||
index = self.contents.find_directive(lambda elem: elem[0] == statement[0])
|
||||
if index < 0:
|
||||
self.contents.add_directive(statement, insert_at_top)
|
||||
return
|
||||
self.contents.update_directive(statement, index)
|
||||
|
||||
def get_directives(self, name, match=None):
|
||||
""" Retrieves any child directive starting with `name`.
|
||||
:param str name: The directive name to fetch.
|
||||
:param callable match: An additional optional filter to specify matching directives.
|
||||
:return: an iterator over matching directives.
|
||||
"""
|
||||
directives = self.contents.get_type(obj.Sentence)
|
||||
return [d for d in directives if len(d) > 0 and d[0] == name and (match is None or match(d))]
|
||||
|
||||
def parse(self, raw_list, add_spaces=False):
|
||||
""" Parses lists into a ServerBlock object, and creates a
|
||||
corresponding VirtualHost metadata object. """
|
||||
super(ServerBlock, self).parse(raw_list, add_spaces)
|
||||
self.vhost = nginx_obj.VirtualHost(
|
||||
self.context.filename if self.context is not None else "",
|
||||
self.addrs, self.ssl, True, self.server_names, self.dump_unspaced_list()[1],
|
||||
self.get_path(), self)
|
||||
self._update_vhost()
|
||||
|
||||
NGINX_PARSING_HOOKS = (ServerBlock, obj.Block, Include, obj.Sentence, obj.Directives)
|
||||
@@ -101,36 +101,42 @@ class RawNginxDumper(object):
|
||||
# Shortcut functions to respect Python's serialization interface
|
||||
# (like pyyaml, picker or json)
|
||||
|
||||
def loads(source):
|
||||
def loads(source, raw=False):
|
||||
"""Parses from a string.
|
||||
|
||||
:param str source: The string to parse
|
||||
:param bool raw: If true, doesn't return an UnspacedList.
|
||||
:returns: The parsed tree
|
||||
:rtype: list
|
||||
|
||||
"""
|
||||
if raw:
|
||||
return RawNginxParser(source).as_list()
|
||||
return UnspacedList(RawNginxParser(source).as_list())
|
||||
|
||||
|
||||
def load(_file):
|
||||
def load(_file, raw=False):
|
||||
"""Parses from a file.
|
||||
|
||||
:param file _file: The file to parse
|
||||
:param bool raw: If true, doesn't return an UnspacedList.
|
||||
:returns: The parsed tree
|
||||
:rtype: list
|
||||
|
||||
"""
|
||||
return loads(_file.read())
|
||||
return loads(_file.read(), raw)
|
||||
|
||||
|
||||
def dumps(blocks):
|
||||
def dumps(blocks, raw=False):
|
||||
"""Dump to a string.
|
||||
|
||||
:param UnspacedList block: The parsed tree
|
||||
:param int indentation: The number of spaces to indent
|
||||
:param UnspacedList or list block: The parsed tree
|
||||
:param bool raw: If true, expects a regular list, not UnspacedList.
|
||||
:rtype: str
|
||||
|
||||
"""
|
||||
if raw:
|
||||
return str(RawNginxDumper(blocks))
|
||||
return str(RawNginxDumper(blocks.spaced))
|
||||
|
||||
|
||||
@@ -189,6 +195,8 @@ class UnspacedList(list):
|
||||
def insert(self, i, x):
|
||||
item, spaced_item = self._coerce(x)
|
||||
slicepos = self._spaced_position(i) if i < len(self) else len(self.spaced)
|
||||
if slicepos > 0 and spacey(self.spaced[slicepos-1]):
|
||||
slicepos -= 1
|
||||
self.spaced.insert(slicepos, spaced_item)
|
||||
if not spacey(item):
|
||||
list.insert(self, i, item)
|
||||
@@ -196,14 +204,20 @@ class UnspacedList(list):
|
||||
|
||||
def append(self, x):
|
||||
item, spaced_item = self._coerce(x)
|
||||
self.spaced.append(spaced_item)
|
||||
if len(self.spaced) > 0 and spacey(self.spaced[-1]):
|
||||
self.spaced.insert(-1, spaced_item)
|
||||
else:
|
||||
self.spaced.append(spaced_item)
|
||||
if not spacey(item):
|
||||
list.append(self, item)
|
||||
self.dirty = True
|
||||
|
||||
def extend(self, x):
|
||||
item, spaced_item = self._coerce(x)
|
||||
self.spaced.extend(spaced_item)
|
||||
if len(self.spaced) > 0 and spacey(self.spaced[-1]):
|
||||
self.spaced[-1:-1] = spaced_item
|
||||
else:
|
||||
self.spaced.extend(spaced_item)
|
||||
list.extend(self, item)
|
||||
self.dirty = True
|
||||
|
||||
|
||||
@@ -161,7 +161,7 @@ class VirtualHost(object): # pylint: disable=too-few-public-methods
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, filep, addrs, ssl, enabled, names, raw, path):
|
||||
def __init__(self, filep, addrs, ssl, enabled, names, raw, path, raw_obj=None):
|
||||
# pylint: disable=too-many-arguments
|
||||
"""Initialize a VH."""
|
||||
self.filep = filep
|
||||
@@ -171,6 +171,7 @@ class VirtualHost(object): # pylint: disable=too-few-public-methods
|
||||
self.enabled = enabled
|
||||
self.raw = raw
|
||||
self.path = path
|
||||
self.raw_obj = raw_obj
|
||||
|
||||
def __str__(self):
|
||||
addr_str = ", ".join(str(addr) for addr in sorted(self.addrs, key=str))
|
||||
|
||||
@@ -12,6 +12,7 @@ from certbot import errors
|
||||
from certbot.compat import os
|
||||
|
||||
from certbot_nginx import obj
|
||||
from certbot_nginx import nginx_parser_obj as nginx_obj
|
||||
from certbot_nginx import nginxparser
|
||||
from acme.magic_typing import Union, Dict, Set, Any, List, Tuple # pylint: disable=unused-import, no-name-in-module
|
||||
|
||||
@@ -32,6 +33,8 @@ class NginxParser(object):
|
||||
self.root = os.path.abspath(root)
|
||||
self.config_root = self._find_config_root()
|
||||
|
||||
self.parsed_root = None
|
||||
|
||||
# Parse nginx.conf and included files.
|
||||
# TODO: Check sites-available/ as well. For now, the configurator does
|
||||
# not enable sites from there.
|
||||
@@ -43,6 +46,10 @@ class NginxParser(object):
|
||||
"""
|
||||
self.parsed = {}
|
||||
self._parse_recursively(self.config_root)
|
||||
if self.parsed_root:
|
||||
return
|
||||
self.parsed_root = nginx_obj.parse_from_file_nginx(
|
||||
nginx_obj.NginxParseContext(cwd=self.root, filename=self.config_root))
|
||||
|
||||
def _parse_recursively(self, filepath):
|
||||
"""Parses nginx config files recursively by looking at 'include'
|
||||
@@ -88,18 +95,14 @@ class NginxParser(object):
|
||||
def _build_addr_to_ssl(self):
|
||||
"""Builds a map from address to whether it listens on ssl in any server block
|
||||
"""
|
||||
servers = self._get_raw_servers()
|
||||
|
||||
addr_to_ssl = {} # type: Dict[Tuple[str, str], bool]
|
||||
for filename in servers:
|
||||
for server, _ in servers[filename]:
|
||||
# Parse the server block to save addr info
|
||||
parsed_server = _parse_server_raw(server)
|
||||
for addr in parsed_server['addrs']:
|
||||
addr_tuple = addr.normalized_tuple()
|
||||
if addr_tuple not in addr_to_ssl:
|
||||
addr_to_ssl[addr_tuple] = addr.ssl
|
||||
addr_to_ssl[addr_tuple] = addr.ssl or addr_to_ssl[addr_tuple]
|
||||
servers = self.parsed_root.get_type(nginx_obj.ServerBlock)
|
||||
addr_to_ssl = {}
|
||||
for server in servers:
|
||||
for addr in server.vhost.addrs:
|
||||
addr_tuple = addr.normalized_tuple()
|
||||
if addr_tuple not in addr_to_ssl:
|
||||
addr_to_ssl[addr_tuple] = addr.ssl
|
||||
addr_to_ssl[addr_tuple] = addr.ssl or addr_to_ssl[addr_tuple]
|
||||
return addr_to_ssl
|
||||
|
||||
def _get_raw_servers(self):
|
||||
@@ -117,12 +120,20 @@ class NginxParser(object):
|
||||
_do_for_subarray(tree, lambda x: len(x) >= 2 and x[0] == ['server'],
|
||||
lambda x, y: srv.append((x[1], y)))
|
||||
|
||||
# Find 'include' statements in server blocks and append their trees
|
||||
# Find 'include statement in server blocks and append their trees
|
||||
for i, (server, path) in enumerate(servers[filename]):
|
||||
new_server = self._get_included_directives(server)
|
||||
servers[filename][i] = (new_server, path)
|
||||
return servers
|
||||
|
||||
def get_vhost(self, filename):
|
||||
""" Retrieve first found vhost in file for testing purposes. """
|
||||
blocks = self.parsed_root.get_type(nginx_obj.ServerBlock)
|
||||
for server_block in blocks:
|
||||
if server_block.vhost.filep == filename:
|
||||
return server_block.vhost
|
||||
return None
|
||||
|
||||
def get_vhosts(self):
|
||||
# pylint: disable=cell-var-from-loop
|
||||
"""Gets list of all 'virtual hosts' found in Nginx configuration.
|
||||
@@ -134,26 +145,11 @@ class NginxParser(object):
|
||||
:rtype: list
|
||||
|
||||
"""
|
||||
enabled = True # We only look at enabled vhosts for now
|
||||
servers = self._get_raw_servers()
|
||||
|
||||
vhosts = []
|
||||
for filename in servers:
|
||||
for server, path in servers[filename]:
|
||||
# Parse the server block into a VirtualHost object
|
||||
|
||||
parsed_server = _parse_server_raw(server)
|
||||
vhost = obj.VirtualHost(filename,
|
||||
parsed_server['addrs'],
|
||||
parsed_server['ssl'],
|
||||
enabled,
|
||||
parsed_server['names'],
|
||||
server,
|
||||
path)
|
||||
vhosts.append(vhost)
|
||||
|
||||
blocks = self.parsed_root.get_type(nginx_obj.ServerBlock)
|
||||
for server_block in blocks:
|
||||
vhosts.append(server_block.vhost)
|
||||
self._update_vhosts_addrs_ssl(vhosts)
|
||||
|
||||
return vhosts
|
||||
|
||||
def _update_vhosts_addrs_ssl(self, vhosts):
|
||||
@@ -295,8 +291,18 @@ class NginxParser(object):
|
||||
of the server block instead of the bottom
|
||||
|
||||
"""
|
||||
self._modify_server_directives(vhost,
|
||||
functools.partial(_add_directives, directives, insert_at_top))
|
||||
|
||||
for directive in directives:
|
||||
if not _is_whitespace_or_empty(directive):
|
||||
vhost.raw_obj.add_directive(directive, insert_at_top)
|
||||
self._sync_old_structs(vhost)
|
||||
|
||||
def _sync_old_structs(self, vhost):
|
||||
vhost.raw_obj._update_vhost()
|
||||
old = self.parsed[vhost.filep]
|
||||
for index in vhost.path[:-1]:
|
||||
old = old[index]
|
||||
old[vhost.path[-1]] = vhost.raw_obj.dump_unspaced_list()
|
||||
|
||||
def update_or_add_server_directives(self, vhost, directives, insert_at_top=False):
|
||||
"""Add or replace directives in the server block identified by vhost.
|
||||
@@ -317,8 +323,10 @@ class NginxParser(object):
|
||||
of the server block instead of the bottom
|
||||
|
||||
"""
|
||||
self._modify_server_directives(vhost,
|
||||
functools.partial(_update_or_add_directives, directives, insert_at_top))
|
||||
for directive in directives:
|
||||
if not _is_whitespace_or_empty(directive):
|
||||
vhost.raw_obj.update_or_add_directive(directive, insert_at_top)
|
||||
self._sync_old_structs(vhost)
|
||||
|
||||
def remove_server_directives(self, vhost, directive_name, match_func=None):
|
||||
"""Remove all directives of type directive_name.
|
||||
@@ -339,6 +347,8 @@ class NginxParser(object):
|
||||
vhost.ssl = parsed_server['ssl']
|
||||
vhost.names = parsed_server['names']
|
||||
vhost.raw = new_server
|
||||
if vhost.raw_obj:
|
||||
vhost.raw_obj.contents.parse(vhost.raw.spaced)
|
||||
|
||||
def _modify_server_directives(self, vhost, block_func):
|
||||
filename = vhost.filep
|
||||
@@ -374,8 +384,10 @@ class NginxParser(object):
|
||||
new_vhost = copy.deepcopy(vhost_template)
|
||||
|
||||
enclosing_block = self.parsed[vhost_template.filep]
|
||||
parsing_obj = self.parsed_root.context.parsed_files[vhost_template.filep]
|
||||
for index in vhost_template.path[:-1]:
|
||||
enclosing_block = enclosing_block[index]
|
||||
parsing_obj = parsing_obj._data[index]
|
||||
raw_in_parsed = copy.deepcopy(enclosing_block[vhost_template.path[-1]])
|
||||
|
||||
if only_directives is not None:
|
||||
@@ -383,11 +395,13 @@ class NginxParser(object):
|
||||
for directive in raw_in_parsed[1]:
|
||||
if directive and directive[0] in only_directives:
|
||||
new_directives.append(directive)
|
||||
new_directives.append("\n")
|
||||
raw_in_parsed[1] = new_directives
|
||||
|
||||
self._update_vhost_based_on_new_directives(new_vhost, new_directives)
|
||||
|
||||
enclosing_block.append(raw_in_parsed)
|
||||
if "\n" not in enclosing_block[-1][0][0]:
|
||||
enclosing_block[-1][0].insert(0, "\n")
|
||||
new_vhost.path[-1] = len(enclosing_block) - 1
|
||||
if remove_singleton_listen_params:
|
||||
for addr in new_vhost.addrs:
|
||||
@@ -406,6 +420,11 @@ class NginxParser(object):
|
||||
keys = [x.split('=')[0] for x in directive]
|
||||
if param in keys:
|
||||
del directive[keys.index(param)]
|
||||
og_vhost_path = vhost_template.raw_obj.get_path()[-1]
|
||||
parsing_obj.parse(enclosing_block.spaced)
|
||||
new_vhost = parsing_obj._data[new_vhost.path[-1]].vhost
|
||||
vhost_template.raw_obj = parsing_obj._data[og_vhost_path]
|
||||
vhost_template.raw_obj.vhost = vhost_template
|
||||
return new_vhost
|
||||
|
||||
|
||||
@@ -552,18 +571,11 @@ def _is_ssl_on_directive(entry):
|
||||
len(entry) == 2 and entry[0] == 'ssl' and
|
||||
entry[1] == 'on')
|
||||
|
||||
def _add_directives(directives, insert_at_top, block):
|
||||
"""Adds directives to a config block."""
|
||||
for directive in directives:
|
||||
_add_directive(block, directive, insert_at_top)
|
||||
if block and '\n' not in block[-1]: # could be " \n " or ["\n"] !
|
||||
block.append(nginxparser.UnspacedList('\n'))
|
||||
|
||||
def _update_or_add_directives(directives, insert_at_top, block):
|
||||
"""Adds or replaces directives in a config block."""
|
||||
for directive in directives:
|
||||
_update_or_add_directive(block, directive, insert_at_top)
|
||||
if block and '\n' not in block[-1]: # could be " \n " or ["\n"] !
|
||||
if block and '\n' not in block.spaced[-1]: # could be " \n " or ["\n"] !
|
||||
block.append(nginxparser.UnspacedList('\n'))
|
||||
|
||||
|
||||
@@ -588,8 +600,6 @@ def comment_directive(block, location):
|
||||
next_entry = next_entry[0]
|
||||
|
||||
block.insert(location + 1, COMMENT_BLOCK[:])
|
||||
if next_entry is not None and "\n" not in next_entry:
|
||||
block.insert(location + 2, '\n')
|
||||
|
||||
def _comment_out_directive(block, location, include_location):
|
||||
"""Comment out the line at location, with a note of explanation."""
|
||||
@@ -624,6 +634,7 @@ def _is_whitespace_or_comment(directive):
|
||||
"""Is this directive either a whitespace or comment directive?"""
|
||||
return len(directive) == 0 or directive[0] == '#'
|
||||
|
||||
# block = Statements
|
||||
def _add_directive(block, directive, insert_at_top):
|
||||
if not isinstance(directive, nginxparser.UnspacedList):
|
||||
directive = nginxparser.UnspacedList(directive)
|
||||
@@ -723,6 +734,13 @@ def _apply_global_addr_ssl(addr_to_ssl, parsed_server):
|
||||
if addr.ssl:
|
||||
parsed_server['ssl'] = True
|
||||
|
||||
def _is_whitespace_or_empty(directive):
|
||||
if not directive:
|
||||
return True
|
||||
if len(directive) != 1:
|
||||
return False
|
||||
return len(directive[0]) == 0 or directive[0].isspace()
|
||||
|
||||
def _parse_server_raw(server):
|
||||
"""Parses a list of server directives.
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ import six
|
||||
|
||||
from certbot import errors
|
||||
|
||||
from certbot_nginx import nginxparser
|
||||
|
||||
from acme.magic_typing import List # pylint: disable=unused-import, no-name-in-module
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -18,25 +20,19 @@ class Parsable(object):
|
||||
""" Abstract base class for "Parsable" objects whose underlying representation
|
||||
is a tree of lists.
|
||||
|
||||
:param .Parsable parent: This object's parsed parent in the tree
|
||||
:param .ParseContext context: This object's context
|
||||
"""
|
||||
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
def __init__(self, parent=None):
|
||||
def __init__(self, context=None):
|
||||
self._data = [] # type: List[object]
|
||||
self._tabs = None
|
||||
self.parent = parent
|
||||
self.context = context
|
||||
|
||||
@classmethod
|
||||
def parsing_hooks(cls):
|
||||
"""Returns object types that this class should be able to `parse` recusrively.
|
||||
The order of the objects indicates the order in which the parser should
|
||||
try to parse each subitem.
|
||||
:returns: A list of Parsable classes.
|
||||
:rtype list:
|
||||
"""
|
||||
return (Block, Sentence, Statements)
|
||||
def get_parent(self):
|
||||
if self.context == None:
|
||||
return None
|
||||
return self.context.parent
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
@@ -76,29 +72,6 @@ class Parsable(object):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_tabs(self):
|
||||
""" Guess at the tabbing style of this parsed object, based on whitespace.
|
||||
|
||||
If this object is a leaf, it deducts the tabbing based on its own contents.
|
||||
Other objects may guess by calling `get_tabs` recursively on child objects.
|
||||
|
||||
:returns: Guess at tabbing for this object. Should only return whitespace strings
|
||||
that does not contain newlines.
|
||||
:rtype str:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_tabs(self, tabs=" "):
|
||||
"""This tries to set and alter the tabbing of the current object to a desired
|
||||
whitespace string. Primarily meant for objects that were constructed, so they
|
||||
can conform to surrounding whitespace.
|
||||
|
||||
:param str tabs: A whitespace string (not containing newlines).
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def dump(self, include_spaces=False):
|
||||
""" Dumps back to pyparsing-like list tree. The opposite of `parse`.
|
||||
|
||||
@@ -113,16 +86,42 @@ class Parsable(object):
|
||||
"""
|
||||
return [elem.dump(include_spaces) for elem in self._data]
|
||||
|
||||
def dump_unspaced_list(self):
|
||||
""" Dumps back to pyparsing-like list tree into an UnspacedList.
|
||||
Use for compatibility with UnspacedList dependencies while migrating
|
||||
to new parsing objects.
|
||||
|
||||
class Statements(Parsable):
|
||||
""" A group or list of "Statements". A Statement is either a Block or a Sentence.
|
||||
:returns: Pyparsing-like list tree.
|
||||
:rtype :class:`.nginxparser.UnspacedList`:
|
||||
"""
|
||||
return nginxparser.UnspacedList(self.dump(True))
|
||||
|
||||
The underlying representation is simply a list of these Statement objects, with
|
||||
|
||||
def child_context(self, filename=None, cwd=None):
|
||||
""" Spawn a child context. """
|
||||
if self.context:
|
||||
return self.context.child(self, filename=filename)
|
||||
return ParseContext(parent=self, filename=filename, cwd=cwd)
|
||||
|
||||
def get_path(self):
|
||||
""" TODO: document and test"""
|
||||
if not self.context.parent or self.context.parent.context.filename != self.context.filename:
|
||||
return None
|
||||
parent_path = self.context.parent.get_path()
|
||||
my_index = self.context.parent._data.index(self)
|
||||
if parent_path:
|
||||
return parent_path + [my_index]
|
||||
return [my_index]
|
||||
|
||||
class Directives(Parsable):
|
||||
""" A group or list of Directives.
|
||||
|
||||
The underlying representation is simply a list of other parsed objects, with
|
||||
an extra `_trailing_whitespace` string to keep track of the whitespace that does not
|
||||
precede any more statements.
|
||||
"""
|
||||
def __init__(self, parent=None):
|
||||
super(Statements, self).__init__(parent)
|
||||
def __init__(self, context=None):
|
||||
super(Directives, self).__init__(context)
|
||||
self._trailing_whitespace = None
|
||||
|
||||
# ======== Begin overridden functions
|
||||
@@ -131,43 +130,32 @@ class Statements(Parsable):
|
||||
def should_parse(lists):
|
||||
return isinstance(lists, list)
|
||||
|
||||
def set_tabs(self, tabs=" "):
|
||||
""" Sets the tabbing for this set of statements. Does this by calling `set_tabs`
|
||||
on each of the child statements.
|
||||
|
||||
Then, if a parent is present, sets trailing whitespace to parent tabbing. This
|
||||
is so that the trailing } of any Block that contains Statements lines up
|
||||
with parent tabbing.
|
||||
"""
|
||||
for statement in self._data:
|
||||
statement.set_tabs(tabs)
|
||||
if self.parent is not None:
|
||||
self._trailing_whitespace = "\n" + self.parent.get_tabs()
|
||||
|
||||
def parse(self, raw_list, add_spaces=False):
|
||||
""" Parses a list of statements.
|
||||
Expects all elements in `raw_list` to be parseable by `type(self).parsing_hooks`,
|
||||
with an optional whitespace string at the last index of `raw_list`.
|
||||
"""
|
||||
if isinstance(raw_list, nginxparser.UnspacedList):
|
||||
raw_list = raw_list.spaced
|
||||
if not isinstance(raw_list, list):
|
||||
raise errors.MisconfigurationError("Statements parsing expects a list!")
|
||||
# If there's a trailing whitespace in the list of statements, keep track of it.
|
||||
if raw_list and isinstance(raw_list[-1], six.string_types) and raw_list[-1].isspace():
|
||||
self._trailing_whitespace = raw_list[-1]
|
||||
raw_list = raw_list[:-1]
|
||||
self._data = [parse_raw(elem, self, add_spaces) for elem in raw_list]
|
||||
|
||||
def get_tabs(self):
|
||||
""" Takes a guess at the tabbing of all contained Statements by retrieving the
|
||||
tabbing of the first Statement."""
|
||||
if self._data:
|
||||
return self._data[0].get_tabs()
|
||||
return ""
|
||||
raise errors.MisconfigurationError("Directives parsing expects a list!")
|
||||
# If there's a trailing whitespace in the list of statements, keep track of them
|
||||
if raw_list:
|
||||
i = -1
|
||||
while len(raw_list) >= -i and isinstance(raw_list[i], six.string_types) and raw_list[i].isspace():
|
||||
i -= 1
|
||||
self._trailing_whitespace = "".join(raw_list[i+1:])
|
||||
raw_list = raw_list[:i+1]
|
||||
# Create parsing objects first, then parse. Then references to parent
|
||||
# data exist while we parse the child objects.
|
||||
self._data = [_choose_parser(self.child_context(), elem) for elem in raw_list]
|
||||
for i, elem in enumerate(raw_list):
|
||||
self._data[i].parse(elem, add_spaces)
|
||||
|
||||
def dump(self, include_spaces=False):
|
||||
""" Dumps this object by first dumping each statement, then appending its
|
||||
trailing whitespace (if `include_spaces` is set) """
|
||||
data = super(Statements, self).dump(include_spaces)
|
||||
data = super(Directives, self).dump(include_spaces)
|
||||
if include_spaces and self._trailing_whitespace is not None:
|
||||
return data + [self._trailing_whitespace]
|
||||
return data
|
||||
@@ -180,6 +168,37 @@ class Statements(Parsable):
|
||||
|
||||
# ======== End overridden functions
|
||||
|
||||
def update_directive(self, statement, index):
|
||||
""" upd8
|
||||
"""
|
||||
self._data[index] = statement
|
||||
if index + 1 >= len(self._data) or not _is_certbot_comment(self._data[index+1]):
|
||||
self._data.insert(index+1, _certbot_comment(self.context))
|
||||
|
||||
def find_directive(self, match_func):
|
||||
for i, elem in enumerate(self._data):
|
||||
if isinstance(elem, Sentence) and match_func(elem):
|
||||
return i
|
||||
return -1
|
||||
|
||||
def add_directive(self, statement, insert_at_top=False):
|
||||
""" Takes in a parse obj
|
||||
"""
|
||||
index = 0
|
||||
if insert_at_top:
|
||||
self._data.insert(0, statement)
|
||||
else:
|
||||
index = len(self._data)
|
||||
self._data.append(statement)
|
||||
if not _is_comment(statement):
|
||||
self._data.insert(index+1, _certbot_comment(self.context))
|
||||
|
||||
def get_type(self, match_type):
|
||||
""" TODO
|
||||
"""
|
||||
return self.iterate(expanded=True,
|
||||
match=lambda elem: isinstance(elem, match_type))
|
||||
|
||||
|
||||
def _space_list(list_):
|
||||
""" Inserts whitespace between adjacent non-whitespace tokens. """
|
||||
@@ -220,31 +239,15 @@ class Sentence(Parsable):
|
||||
|
||||
def iterate(self, expanded=False, match=None):
|
||||
""" Simply yields itself. """
|
||||
if match is None or match(self):
|
||||
if (match is None) or match(self):
|
||||
yield self
|
||||
|
||||
def set_tabs(self, tabs=" "):
|
||||
""" Sets the tabbing on this sentence. Inserts a newline and `tabs` at the
|
||||
beginning of `self._data`. """
|
||||
if self._data[0].isspace():
|
||||
return
|
||||
self._data.insert(0, "\n" + tabs)
|
||||
|
||||
def dump(self, include_spaces=False):
|
||||
""" Dumps this sentence. If include_spaces is set, includes whitespace tokens."""
|
||||
if not include_spaces:
|
||||
return self.words
|
||||
return self._data
|
||||
|
||||
def get_tabs(self):
|
||||
""" Guesses at the tabbing of this sentence. If the first element is whitespace,
|
||||
returns the whitespace after the rightmost newline in the string. """
|
||||
first = self._data[0]
|
||||
if not first.isspace():
|
||||
return ""
|
||||
rindex = first.rfind("\n")
|
||||
return first[rindex+1:]
|
||||
|
||||
# ======== End overridden functions
|
||||
|
||||
@property
|
||||
@@ -252,6 +255,9 @@ class Sentence(Parsable):
|
||||
""" Iterates over words, but without spaces. Like Unspaced List. """
|
||||
return [word.strip("\"\'") for word in self._data if not word.isspace()]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.words)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.words[index]
|
||||
|
||||
@@ -270,8 +276,8 @@ class Block(Parsable):
|
||||
names = ["block", " ", "name", " "]
|
||||
contents = [["\n ", "content", " ", "1"], ["\n ", "content", " ", "2"], "\n"]
|
||||
"""
|
||||
def __init__(self, parent=None):
|
||||
super(Block, self).__init__(parent)
|
||||
def __init__(self, context=None):
|
||||
super(Block, self).__init__(context)
|
||||
self.names = None # type: Sentence
|
||||
self.contents = None # type: Block
|
||||
|
||||
@@ -279,7 +285,7 @@ class Block(Parsable):
|
||||
def should_parse(lists):
|
||||
""" Returns True if `lists` can be parseable as a `Block`-- that is,
|
||||
it's got a length of 2, the first element is a `Sentence` and the second can be
|
||||
a `Statements`.
|
||||
a `Directives`.
|
||||
|
||||
:param list lists: The raw unparsed list to check.
|
||||
|
||||
@@ -287,12 +293,6 @@ class Block(Parsable):
|
||||
return isinstance(lists, list) and len(lists) == 2 and \
|
||||
Sentence.should_parse(lists[0]) and isinstance(lists[1], list)
|
||||
|
||||
def set_tabs(self, tabs=" "):
|
||||
""" Sets tabs by setting equivalent tabbing on names, then adding tabbing
|
||||
to contents."""
|
||||
self.names.set_tabs(tabs)
|
||||
self.contents.set_tabs(tabs + " ")
|
||||
|
||||
def iterate(self, expanded=False, match=None):
|
||||
""" Iterator over self, and if expanded is set, over its contents. """
|
||||
if match is None or match(self):
|
||||
@@ -314,18 +314,14 @@ class Block(Parsable):
|
||||
if not Block.should_parse(raw_list):
|
||||
raise errors.MisconfigurationError("Block parsing expects a list of length 2. "
|
||||
"First element should be a list of string types (the bloc names), "
|
||||
"and second should be another list of statements (the bloc content).")
|
||||
self.names = Sentence(self)
|
||||
"and second should be another list of directives (the bloc content).")
|
||||
self.names = Sentence(self.child_context())
|
||||
self.contents = Directives(self.child_context())
|
||||
self._data = [self.names, self.contents]
|
||||
if add_spaces:
|
||||
raw_list[0].append(" ")
|
||||
self.names.parse(raw_list[0], add_spaces)
|
||||
self.contents = Statements(self)
|
||||
self.contents.parse(raw_list[1], add_spaces)
|
||||
self._data = [self.names, self.contents]
|
||||
|
||||
def get_tabs(self):
|
||||
""" Guesses tabbing by retrieving tabbing guess of self.names. """
|
||||
return self.names.get_tabs()
|
||||
|
||||
def _is_comment(parsed_obj):
|
||||
""" Checks whether parsed_obj is a comment.
|
||||
@@ -356,34 +352,34 @@ def _is_certbot_comment(parsed_obj):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _certbot_comment(parent, preceding_spaces=4):
|
||||
def _certbot_comment(context, preceding_spaces=4):
|
||||
""" A "Managed by Certbot" comment.
|
||||
:param int preceding_spaces: Number of spaces between the end of the previous
|
||||
statement and the comment.
|
||||
:returns: Sentence containing the comment.
|
||||
:rtype: .Sentence
|
||||
"""
|
||||
result = Sentence(parent)
|
||||
result = Sentence(context)
|
||||
result.parse([" " * preceding_spaces] + COMMENT_BLOCK)
|
||||
return result
|
||||
|
||||
def _choose_parser(parent, list_):
|
||||
""" Choose a parser from type(parent).parsing_hooks, depending on whichever hook
|
||||
def _choose_parser(context, list_):
|
||||
""" Choose a parser from type(context).parsing_hooks, depending on whichever hook
|
||||
returns True first. """
|
||||
hooks = Parsable.parsing_hooks()
|
||||
if parent:
|
||||
hooks = type(parent).parsing_hooks()
|
||||
hooks = ParseContext.parsing_hooks()
|
||||
if context:
|
||||
hooks = type(context).parsing_hooks()
|
||||
for type_ in hooks:
|
||||
if type_.should_parse(list_):
|
||||
return type_(parent)
|
||||
return type_(context)
|
||||
raise errors.MisconfigurationError(
|
||||
"None of the parsing hooks succeeded, so we don't know how to parse this set of lists.")
|
||||
|
||||
def parse_raw(lists_, parent=None, add_spaces=False):
|
||||
def parse_raw(lists_, context=None, add_spaces=False):
|
||||
""" Primary parsing factory function.
|
||||
|
||||
:param list lists_: raw lists from pyparsing to parse.
|
||||
:param .Parent parent: The parent containing this object.
|
||||
:param .ParseContext context: The context of this object.
|
||||
:param bool add_spaces: Whether to pass add_spaces to the parser.
|
||||
|
||||
:returns .Parsable: The parsed object.
|
||||
@@ -391,6 +387,38 @@ def parse_raw(lists_, parent=None, add_spaces=False):
|
||||
:raises errors.MisconfigurationError: If no parsing hook passes, and we can't
|
||||
determine which type to parse the raw lists into.
|
||||
"""
|
||||
parser = _choose_parser(parent, lists_)
|
||||
if context is None:
|
||||
context = ParseContext()
|
||||
parser = _choose_parser(context, lists_)
|
||||
parser.parse(lists_, add_spaces)
|
||||
return parser
|
||||
|
||||
class ParseContext(object):
|
||||
""" Context information held by parsed objects.
|
||||
|
||||
:param .Parsable parent: The parent object containing the associated object.
|
||||
:param str filename: relative file path that the associated object was parsed from
|
||||
:param str cwd: current working directory/root of the parsing files
|
||||
"""
|
||||
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
def __init__(self, parent=None, filename=None, cwd=None):
|
||||
self.parent = parent
|
||||
self.filename = filename
|
||||
self.cwd = cwd
|
||||
|
||||
def child(self, parent, filename=None):
|
||||
""" Returns Context with all fields inherited, except parent points to this object.
|
||||
"""
|
||||
return ParseContext(parent, filename if filename else self.filename, self.cwd)
|
||||
|
||||
@staticmethod
|
||||
def parsing_hooks():
|
||||
"""Returns object types that this class should be able to `parse` recusrively.
|
||||
The order of the objects indicates the order in which the parser should
|
||||
try to parse each subitem.
|
||||
:returns: A list of Parsable classes.
|
||||
:rtype list:
|
||||
"""
|
||||
return (Block, Sentence, Directives)
|
||||
|
||||
@@ -177,6 +177,7 @@ class NginxConfiguratorTest(util.NginxTest):
|
||||
if name == "ipv6.com":
|
||||
self.assertTrue(vhost.ipv6_enabled())
|
||||
# Make sure that we have SSL enabled also for IPv6 addr
|
||||
print vhost
|
||||
self.assertTrue(
|
||||
any([True for x in vhost.addrs if x.ssl and x.ipv6]))
|
||||
|
||||
@@ -494,23 +495,23 @@ class NginxConfiguratorTest(util.NginxTest):
|
||||
self.assertEqual(
|
||||
[[['server'], [
|
||||
['server_name', '.example.com'],
|
||||
['server_name', 'example.*'], [],
|
||||
['server_name', 'example.*'],
|
||||
['listen', '5001', 'ssl'], ['#', ' managed by Certbot'],
|
||||
['ssl_certificate', 'example/fullchain.pem'], ['#', ' managed by Certbot'],
|
||||
['ssl_certificate_key', 'example/key.pem'], ['#', ' managed by Certbot'],
|
||||
['include', self.config.mod_ssl_conf], ['#', ' managed by Certbot'],
|
||||
['ssl_dhparam', self.config.ssl_dhparams], ['#', ' managed by Certbot'],
|
||||
[], []]],
|
||||
]],
|
||||
[['server'], [
|
||||
[['if', '($host', '=', 'www.example.com)'], [
|
||||
['return', '301', 'https://$host$request_uri']]],
|
||||
['#', ' managed by Certbot'], [],
|
||||
['#', ' managed by Certbot'],
|
||||
['listen', '69.50.225.155:9000'],
|
||||
['listen', '127.0.0.1'],
|
||||
['server_name', '.example.com'],
|
||||
['server_name', 'example.*'],
|
||||
['return', '404'], ['#', ' managed by Certbot'], [], [], []]]],
|
||||
generated_conf)
|
||||
['return', '404'], ['#', ' managed by Certbot'], ]]][0],
|
||||
generated_conf[0])
|
||||
|
||||
def test_split_for_headers(self):
|
||||
example_conf = self.config.parser.abs_path('sites-enabled/example.com')
|
||||
@@ -525,22 +526,21 @@ class NginxConfiguratorTest(util.NginxTest):
|
||||
self.assertEqual(
|
||||
[[['server'], [
|
||||
['server_name', '.example.com'],
|
||||
['server_name', 'example.*'], [],
|
||||
['server_name', 'example.*'],
|
||||
['listen', '5001', 'ssl'], ['#', ' managed by Certbot'],
|
||||
['ssl_certificate', 'example/fullchain.pem'], ['#', ' managed by Certbot'],
|
||||
['ssl_certificate_key', 'example/key.pem'], ['#', ' managed by Certbot'],
|
||||
['include', self.config.mod_ssl_conf], ['#', ' managed by Certbot'],
|
||||
['ssl_dhparam', self.config.ssl_dhparams], ['#', ' managed by Certbot'],
|
||||
[], [],
|
||||
['add_header', 'Strict-Transport-Security', '"max-age=31536000"', 'always'],
|
||||
['#', ' managed by Certbot'],
|
||||
[], []]],
|
||||
]],
|
||||
[['server'], [
|
||||
['listen', '69.50.225.155:9000'],
|
||||
['listen', '127.0.0.1'],
|
||||
['server_name', '.example.com'],
|
||||
['server_name', 'example.*'],
|
||||
[], [], []]]],
|
||||
]]],
|
||||
generated_conf)
|
||||
|
||||
def test_http_header_hsts(self):
|
||||
@@ -633,6 +633,7 @@ class NginxConfiguratorTest(util.NginxTest):
|
||||
default_conf = self.config.parser.abs_path('sites-enabled/default')
|
||||
foo_conf = self.config.parser.abs_path('foo.conf')
|
||||
del self.config.parser.parsed[foo_conf][2][1][0][1][0] # remove default_server
|
||||
self.config.parser.get_vhost(foo_conf).raw_obj.parse(self.config.parser.parsed[foo_conf][2][1][0])
|
||||
self.config.version = (1, 3, 1)
|
||||
|
||||
self.config.deploy_cert(
|
||||
@@ -687,6 +688,7 @@ class NginxConfiguratorTest(util.NginxTest):
|
||||
foo_conf = self.config.parser.abs_path('foo.conf')
|
||||
del self.config.parser.parsed[default_conf][0][1][0]
|
||||
del self.config.parser.parsed[default_conf][0][1][0]
|
||||
self.config.parser.get_vhost(default_conf).raw_obj.parse(self.config.parser.parsed[default_conf][0])
|
||||
self.config.version = (1, 3, 1)
|
||||
|
||||
self.config.deploy_cert(
|
||||
@@ -721,6 +723,8 @@ class NginxConfiguratorTest(util.NginxTest):
|
||||
del self.config.parser.parsed[default_conf][0][1][0]
|
||||
del self.config.parser.parsed[default_conf][0][1][0]
|
||||
del self.config.parser.parsed[foo_conf][2][1][0][1][0]
|
||||
self.config.parser.get_vhost(foo_conf).raw_obj.parse(self.config.parser.parsed[foo_conf][2][1][0])
|
||||
self.config.parser.get_vhost(default_conf).raw_obj.parse(self.config.parser.parsed[default_conf][0])
|
||||
self.config.version = (1, 3, 1)
|
||||
|
||||
self.assertRaises(errors.MisconfigurationError, self.config.deploy_cert,
|
||||
@@ -736,6 +740,7 @@ class NginxConfiguratorTest(util.NginxTest):
|
||||
def test_deploy_no_match_multiple_defaults_ok(self):
|
||||
foo_conf = self.config.parser.abs_path('foo.conf')
|
||||
self.config.parser.parsed[foo_conf][2][1][0][1][0][1] = '*:5001'
|
||||
self.config.parser.get_vhost(foo_conf).raw_obj.parse(self.config.parser.parsed[foo_conf][2][1][0])
|
||||
self.config.version = (1, 3, 1)
|
||||
self.config.deploy_cert("www.nomatch.com", "example/cert.pem", "example/key.pem",
|
||||
"example/chain.pem", "example/fullchain.pem")
|
||||
@@ -744,6 +749,7 @@ class NginxConfiguratorTest(util.NginxTest):
|
||||
default_conf = self.config.parser.abs_path('sites-enabled/default')
|
||||
foo_conf = self.config.parser.abs_path('foo.conf')
|
||||
del self.config.parser.parsed[foo_conf][2][1][0][1][0] # remove default_server
|
||||
self.config.parser.get_vhost(foo_conf).raw_obj.parse(self.config.parser.parsed[foo_conf][2][1][0])
|
||||
self.config.version = (1, 3, 1)
|
||||
|
||||
self.config.deploy_cert(
|
||||
|
||||
@@ -53,7 +53,7 @@ class ParsingHooksTest(unittest.TestCase):
|
||||
self.assertFalse(Block.should_parse([['block_name'], 'lol']))
|
||||
self.assertTrue(Block.should_parse([['block_name'], ['hi', []]]))
|
||||
self.assertTrue(Block.should_parse([['hello'], []]))
|
||||
self.assertTrue(Block.should_parse([['block_name'], [['many'], ['statements'], 'here']]))
|
||||
self.assertTrue(Block.should_parse([['block_name'], [['many'], ['directives'], 'here']]))
|
||||
self.assertTrue(Block.should_parse([['if', ' ', '(whatever)'], ['hi']]))
|
||||
|
||||
def test_parse_raw(self):
|
||||
@@ -71,7 +71,7 @@ class ParsingHooksTest(unittest.TestCase):
|
||||
fake_parser1.not_called()
|
||||
fake_parser2.called_once()
|
||||
|
||||
@mock.patch("certbot_nginx.parser_obj.Parsable.parsing_hooks")
|
||||
@mock.patch("certbot_nginx.parser_obj.ParseContext.parsing_hooks")
|
||||
def test_parse_raw_no_match(self, parsing_hooks):
|
||||
from certbot import errors
|
||||
fake_parser1 = mock.Mock()
|
||||
@@ -119,22 +119,6 @@ class SentenceTest(unittest.TestCase):
|
||||
for i, sentence in enumerate(self.sentence.iterate()):
|
||||
self.assertEqual(sentence.dump(), expected[i])
|
||||
|
||||
def test_set_tabs(self):
|
||||
self.sentence.parse(['tabs', 'pls'], add_spaces=True)
|
||||
self.sentence.set_tabs()
|
||||
self.assertEqual(self.sentence.dump(True)[0], '\n ')
|
||||
self.sentence.parse(['tabs', 'pls'], add_spaces=True)
|
||||
|
||||
def test_get_tabs(self):
|
||||
self.sentence.parse(['no', 'tabs'])
|
||||
self.assertEqual(self.sentence.get_tabs(), '')
|
||||
self.sentence.parse(['\n \n ', 'tabs'])
|
||||
self.assertEqual(self.sentence.get_tabs(), ' ')
|
||||
self.sentence.parse(['\n\t ', 'tabs'])
|
||||
self.assertEqual(self.sentence.get_tabs(), '\t ')
|
||||
self.sentence.parse(['\n\t \n', 'tabs'])
|
||||
self.assertEqual(self.sentence.get_tabs(), '')
|
||||
|
||||
class BlockTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
from certbot_nginx.parser_obj import Block
|
||||
@@ -179,21 +163,10 @@ class BlockTest(unittest.TestCase):
|
||||
self.assertRaises(errors.MisconfigurationError, self.bloc.parse, ['lol'])
|
||||
self.assertRaises(errors.MisconfigurationError, self.bloc.parse, ['fake', 'news'])
|
||||
|
||||
def test_set_tabs(self):
|
||||
self.bloc.set_tabs()
|
||||
self.assertEqual(self.bloc.names.dump(True)[0], '\n ')
|
||||
for elem in self.bloc.contents.dump(True)[:-1]:
|
||||
self.assertEqual(elem[0], '\n ')
|
||||
self.assertEqual(self.bloc.contents.dump(True)[-1][0], '\n')
|
||||
|
||||
def test_get_tabs(self):
|
||||
self.bloc.parse([[' \n \t', 'lol'], []])
|
||||
self.assertEqual(self.bloc.get_tabs(), ' \t')
|
||||
|
||||
class StatementsTest(unittest.TestCase):
|
||||
class DirectivesTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
from certbot_nginx.parser_obj import Statements
|
||||
self.statements = Statements(None)
|
||||
from certbot_nginx.parser_obj import Directives
|
||||
self.directives = Directives(None)
|
||||
self.raw = [
|
||||
['sentence', 'one'],
|
||||
['sentence', 'two'],
|
||||
@@ -206,47 +179,24 @@ class StatementsTest(unittest.TestCase):
|
||||
'\n\n'
|
||||
]
|
||||
|
||||
def test_set_tabs(self):
|
||||
self.statements.parse(self.raw)
|
||||
self.statements.set_tabs()
|
||||
for statement in self.statements.iterate():
|
||||
self.assertEqual(statement.dump(True)[0], '\n ')
|
||||
|
||||
def test_set_tabs_with_parent(self):
|
||||
# Trailing whitespace should inherit from parent tabbing.
|
||||
self.statements.parse(self.raw)
|
||||
self.statements.parent = mock.Mock()
|
||||
self.statements.parent.get_tabs.return_value = '\t\t'
|
||||
self.statements.set_tabs()
|
||||
for statement in self.statements.iterate():
|
||||
self.assertEqual(statement.dump(True)[0], '\n ')
|
||||
self.assertEqual(self.statements.dump(True)[-1], '\n\t\t')
|
||||
|
||||
def test_get_tabs(self):
|
||||
self.raw[0].insert(0, '\n \n \t')
|
||||
self.statements.parse(self.raw)
|
||||
self.assertEqual(self.statements.get_tabs(), ' \t')
|
||||
self.statements.parse([])
|
||||
self.assertEqual(self.statements.get_tabs(), '')
|
||||
|
||||
def test_parse_with_added_spaces(self):
|
||||
self.statements.parse(self.raw, add_spaces=True)
|
||||
self.assertEqual(self.statements.dump(True)[0], ['sentence', ' ', 'one'])
|
||||
self.directives.parse(self.raw, add_spaces=True)
|
||||
self.assertEqual(self.directives.dump(True)[0], ['sentence', ' ', 'one'])
|
||||
|
||||
def test_parse_bad_list_raises_error(self):
|
||||
from certbot import errors
|
||||
self.assertRaises(errors.MisconfigurationError, self.statements.parse, 'lol not a list')
|
||||
self.assertRaises(errors.MisconfigurationError, self.directives.parse, 'lol not a list')
|
||||
|
||||
def test_parse_hides_trailing_whitespace(self):
|
||||
self.statements.parse(self.raw + ['\n\n '])
|
||||
self.assertTrue(isinstance(self.statements.dump()[-1], list))
|
||||
self.assertTrue(self.statements.dump(True)[-1].isspace())
|
||||
self.assertEqual(self.statements.dump(True)[-1], '\n\n ')
|
||||
self.directives.parse(self.raw + ['\n\n '])
|
||||
self.assertTrue(isinstance(self.directives.dump()[-1], list))
|
||||
self.assertTrue(self.directives.dump(True)[-1].isspace())
|
||||
self.assertEqual(self.directives.dump(True)[-1], '\n\n ')
|
||||
|
||||
def test_iterate(self):
|
||||
self.statements.parse(self.raw)
|
||||
self.directives.parse(self.raw)
|
||||
expected = [['sentence', 'one'], ['sentence', 'two']]
|
||||
for i, elem in enumerate(self.statements.iterate(match=lambda x: 'sentence' in x)):
|
||||
for i, elem in enumerate(self.directives.iterate(match=lambda x: 'sentence' in x)):
|
||||
self.assertEqual(expected[i], elem.dump())
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -365,7 +365,6 @@ class NginxParserTest(util.NginxTest): #pylint: disable=too-many-public-methods
|
||||
self.assertEqual(block.spaced, [
|
||||
["\n", "a", " ", "b", "\n"],
|
||||
COMMENT_BLOCK,
|
||||
"\n",
|
||||
["c", " ", "d"],
|
||||
COMMENT_BLOCK,
|
||||
["\n", "e", " ", "f"]])
|
||||
|
||||
Reference in New Issue
Block a user