Compare commits

...

4 Commits

Author SHA1 Message Date
sydneyli
fcc76618fa Everything working with add_directive except save 2019-06-07 16:48:43 -07:00
sydneyli
7a79c55af8 Create and load parser objects 2019-06-07 16:48:08 -07:00
sydneyli
e52cd73b84 fixes to UnspacedList to preserve newlines at the end of directive blocks 2019-06-07 11:59:02 -07:00
sydneyli
6c6dd3dd1a Allow nginxparser to load to & dump from regular lists 2019-06-07 11:59:02 -07:00
8 changed files with 430 additions and 242 deletions

View File

@@ -0,0 +1,172 @@
"""NginxParser is a member object of the NginxConfigurator class."""
import glob
import logging
import pyparsing
from certbot import errors
from certbot.compat import os
from certbot_nginx import nginxparser
from certbot_nginx import parser_obj as obj
from certbot_nginx import obj as nginx_obj
logger = logging.getLogger(__name__)
class NginxParseContext(obj.ParseContext):
""" A parsing context which includes a set of parsing hooks specific to Nginx
configuration files. """
def __init__(self, parent=None, filename=None, cwd=None, parsed_files=None):
super(NginxParseContext, self).__init__(parent, filename, cwd)
self.parsed_files = parsed_files if parsed_files else {}
@staticmethod
def parsing_hooks():
return NGINX_PARSING_HOOKS
def child(self, parent, filename=None):
return NginxParseContext(parent,filename if filename else self.filename,
self.cwd, self.parsed_files)
def parse_from_file_nginx(context):
""" Parses from a file specified by `context`.
:param NginxParseContext context:
:returns WithLists:
"""
raw_parsed = []
with open(os.path.join(context.cwd, context.filename)) as _file:
try:
raw_parsed = nginxparser.load(_file, True)
except pyparsing.ParseException as err:
logger.debug("Could not parse file: %s due to %s", context.filename, err)
context.parsed_files[context.filename] = None
parsed = obj.parse_raw(raw_parsed, context=context)
parsed.context.parsed_files[context.filename] = parsed
return parsed
class Include(obj.Sentence):
""" Represents an include statement. On parsing, tries to read and parse included file(s), while
avoiding duplicates from `context.parsed`."""
def __init__(self, context):
super(Include, self).__init__(context)
self.parsed = dict()
@staticmethod
def should_parse(lists):
return obj.Sentence.should_parse(lists) and "include" in lists
def parse(self, raw_list, add_spaces=False):
""" Parsing an include this will try to fetch the associated files (if they exist)
and parses them all. Any parsed files are added to the global context.parsed_files object.
"""
super(Include, self).parse(raw_list, add_spaces)
filepath = self.filename
if not os.path.isabs(filepath):
filepath = os.path.join(self.context.cwd, self.filename)
for f in glob.glob(filepath):
self.parsed[f] = self.context.parsed_files[f] if f in self.context.parsed_files else \
parse_from_file_nginx(self.child_context(f))
@property
def filename(self):
""" Retrieves the filename that is being included. """
return self.words[1]
def iterate(self, expanded=False, match=None):
""" Iterates itself, and if expanded is set, iterates over the `Directives` objects
in all of the included files.
"""
if match is None or match(self):
yield self
if expanded:
for parsed in self.parsed.values():
for sub_elem in parsed.iterate(expanded, match):
yield sub_elem
class ServerBlock(obj.Block):
""" Parsing object which represents an Nginx server block.
This bloc should parallel a "VirtualHost" object-- any update or modification should
also update the corresponding virtual host object. """
REPEATABLE_DIRECTIVES = set(['server_name', 'listen', 'include', 'rewrite', 'add_header'])
def __init__(self, context=None):
super(ServerBlock, self).__init__(context)
self.vhost = None
self.addrs = set()
self.ssl = False
self.server_names = set()
@staticmethod
def should_parse(lists):
return obj.Block.should_parse(lists) and "server" in lists[0]
def _update_vhost(self):
# copied from _parse_server_raw
self.addrs = set()
self.ssl = False
self.server_names = set()
apply_ssl_to_all_addrs = False
for directive in self.contents.get_type(obj.Sentence):
if len(directive.words) == 0:
continue
if directive[0] == 'listen':
addr = nginx_obj.Addr.fromstring(" ".join(directive[1:]))
if addr:
self.addrs.add(addr)
if addr.ssl:
self.ssl = True
if directive[0] == 'server_name':
self.server_names.update(x.strip('"\'') for x in directive[1:])
for ssl in self.get_directives('ssl'):
if ssl[1] == "on":
self.ssl = True
apply_ssl_to_all_addrs = True
if apply_ssl_to_all_addrs:
for addr in self.addrs:
addr.ssl = True
self.vhost.addrs = self.addrs
self.vhost.names = self.server_names
self.vhost.ssl = self.ssl
self.vhost.raw = self.dump_unspaced_list()[1]
self.vhost.raw_obj = self
def add_directive(self, raw_list, insert_at_top=False):
""" Adds a single directive to this Server Block's contents, while enforcing
repeatability rules."""
statement = obj.parse_raw(raw_list, self.contents.child_context(), add_spaces=False)
if isinstance(statement, obj.Sentence) and statement[0] not in self.REPEATABLE_DIRECTIVES \
and len(list(self.get_directives(statement[0]))) > 0:
raise errors.MisconfigurationError(
"Existing %s directive conflicts with %s" % (statement[0], statement))
self.contents.add_directive(statement, insert_at_top)
def update_or_add_directive(self, raw_list, insert_at_top=False):
""" Adds a single directive to this Server Block's contents, while enforcing
repeatability rules."""
statement = obj.parse_raw(raw_list, self.contents.child_context(), add_spaces=False)
index = self.contents.find_directive(lambda elem: elem[0] == statement[0])
if index < 0:
self.contents.add_directive(statement, insert_at_top)
return
self.contents.update_directive(statement, index)
def get_directives(self, name, match=None):
""" Retrieves any child directive starting with `name`.
:param str name: The directive name to fetch.
:param callable match: An additional optional filter to specify matching directives.
:return: an iterator over matching directives.
"""
directives = self.contents.get_type(obj.Sentence)
return [d for d in directives if len(d) > 0 and d[0] == name and (match is None or match(d))]
def parse(self, raw_list, add_spaces=False):
""" Parses lists into a ServerBlock object, and creates a
corresponding VirtualHost metadata object. """
super(ServerBlock, self).parse(raw_list, add_spaces)
self.vhost = nginx_obj.VirtualHost(
self.context.filename if self.context is not None else "",
self.addrs, self.ssl, True, self.server_names, self.dump_unspaced_list()[1],
self.get_path(), self)
self._update_vhost()
NGINX_PARSING_HOOKS = (ServerBlock, obj.Block, Include, obj.Sentence, obj.Directives)

View File

@@ -101,36 +101,42 @@ class RawNginxDumper(object):
# Shortcut functions to respect Python's serialization interface
# (like pyyaml, picker or json)
def loads(source):
def loads(source, raw=False):
"""Parses from a string.
:param str source: The string to parse
:param bool raw: If true, doesn't return an UnspacedList.
:returns: The parsed tree
:rtype: list
"""
if raw:
return RawNginxParser(source).as_list()
return UnspacedList(RawNginxParser(source).as_list())
def load(_file):
def load(_file, raw=False):
"""Parses from a file.
:param file _file: The file to parse
:param bool raw: If true, doesn't return an UnspacedList.
:returns: The parsed tree
:rtype: list
"""
return loads(_file.read())
return loads(_file.read(), raw)
def dumps(blocks):
def dumps(blocks, raw=False):
"""Dump to a string.
:param UnspacedList block: The parsed tree
:param int indentation: The number of spaces to indent
:param UnspacedList or list block: The parsed tree
:param bool raw: If true, expects a regular list, not UnspacedList.
:rtype: str
"""
if raw:
return str(RawNginxDumper(blocks))
return str(RawNginxDumper(blocks.spaced))
@@ -189,6 +195,8 @@ class UnspacedList(list):
def insert(self, i, x):
item, spaced_item = self._coerce(x)
slicepos = self._spaced_position(i) if i < len(self) else len(self.spaced)
if slicepos > 0 and spacey(self.spaced[slicepos-1]):
slicepos -= 1
self.spaced.insert(slicepos, spaced_item)
if not spacey(item):
list.insert(self, i, item)
@@ -196,14 +204,20 @@ class UnspacedList(list):
def append(self, x):
item, spaced_item = self._coerce(x)
self.spaced.append(spaced_item)
if len(self.spaced) > 0 and spacey(self.spaced[-1]):
self.spaced.insert(-1, spaced_item)
else:
self.spaced.append(spaced_item)
if not spacey(item):
list.append(self, item)
self.dirty = True
def extend(self, x):
item, spaced_item = self._coerce(x)
self.spaced.extend(spaced_item)
if len(self.spaced) > 0 and spacey(self.spaced[-1]):
self.spaced[-1:-1] = spaced_item
else:
self.spaced.extend(spaced_item)
list.extend(self, item)
self.dirty = True

View File

@@ -161,7 +161,7 @@ class VirtualHost(object): # pylint: disable=too-few-public-methods
"""
def __init__(self, filep, addrs, ssl, enabled, names, raw, path):
def __init__(self, filep, addrs, ssl, enabled, names, raw, path, raw_obj=None):
# pylint: disable=too-many-arguments
"""Initialize a VH."""
self.filep = filep
@@ -171,6 +171,7 @@ class VirtualHost(object): # pylint: disable=too-few-public-methods
self.enabled = enabled
self.raw = raw
self.path = path
self.raw_obj = raw_obj
def __str__(self):
addr_str = ", ".join(str(addr) for addr in sorted(self.addrs, key=str))

View File

@@ -12,6 +12,7 @@ from certbot import errors
from certbot.compat import os
from certbot_nginx import obj
from certbot_nginx import nginx_parser_obj as nginx_obj
from certbot_nginx import nginxparser
from acme.magic_typing import Union, Dict, Set, Any, List, Tuple # pylint: disable=unused-import, no-name-in-module
@@ -32,6 +33,8 @@ class NginxParser(object):
self.root = os.path.abspath(root)
self.config_root = self._find_config_root()
self.parsed_root = None
# Parse nginx.conf and included files.
# TODO: Check sites-available/ as well. For now, the configurator does
# not enable sites from there.
@@ -43,6 +46,10 @@ class NginxParser(object):
"""
self.parsed = {}
self._parse_recursively(self.config_root)
if self.parsed_root:
return
self.parsed_root = nginx_obj.parse_from_file_nginx(
nginx_obj.NginxParseContext(cwd=self.root, filename=self.config_root))
def _parse_recursively(self, filepath):
"""Parses nginx config files recursively by looking at 'include'
@@ -88,18 +95,14 @@ class NginxParser(object):
def _build_addr_to_ssl(self):
"""Builds a map from address to whether it listens on ssl in any server block
"""
servers = self._get_raw_servers()
addr_to_ssl = {} # type: Dict[Tuple[str, str], bool]
for filename in servers:
for server, _ in servers[filename]:
# Parse the server block to save addr info
parsed_server = _parse_server_raw(server)
for addr in parsed_server['addrs']:
addr_tuple = addr.normalized_tuple()
if addr_tuple not in addr_to_ssl:
addr_to_ssl[addr_tuple] = addr.ssl
addr_to_ssl[addr_tuple] = addr.ssl or addr_to_ssl[addr_tuple]
servers = self.parsed_root.get_type(nginx_obj.ServerBlock)
addr_to_ssl = {}
for server in servers:
for addr in server.vhost.addrs:
addr_tuple = addr.normalized_tuple()
if addr_tuple not in addr_to_ssl:
addr_to_ssl[addr_tuple] = addr.ssl
addr_to_ssl[addr_tuple] = addr.ssl or addr_to_ssl[addr_tuple]
return addr_to_ssl
def _get_raw_servers(self):
@@ -117,12 +120,20 @@ class NginxParser(object):
_do_for_subarray(tree, lambda x: len(x) >= 2 and x[0] == ['server'],
lambda x, y: srv.append((x[1], y)))
# Find 'include' statements in server blocks and append their trees
# Find 'include statement in server blocks and append their trees
for i, (server, path) in enumerate(servers[filename]):
new_server = self._get_included_directives(server)
servers[filename][i] = (new_server, path)
return servers
def get_vhost(self, filename):
""" Retrieve first found vhost in file for testing purposes. """
blocks = self.parsed_root.get_type(nginx_obj.ServerBlock)
for server_block in blocks:
if server_block.vhost.filep == filename:
return server_block.vhost
return None
def get_vhosts(self):
# pylint: disable=cell-var-from-loop
"""Gets list of all 'virtual hosts' found in Nginx configuration.
@@ -134,26 +145,11 @@ class NginxParser(object):
:rtype: list
"""
enabled = True # We only look at enabled vhosts for now
servers = self._get_raw_servers()
vhosts = []
for filename in servers:
for server, path in servers[filename]:
# Parse the server block into a VirtualHost object
parsed_server = _parse_server_raw(server)
vhost = obj.VirtualHost(filename,
parsed_server['addrs'],
parsed_server['ssl'],
enabled,
parsed_server['names'],
server,
path)
vhosts.append(vhost)
blocks = self.parsed_root.get_type(nginx_obj.ServerBlock)
for server_block in blocks:
vhosts.append(server_block.vhost)
self._update_vhosts_addrs_ssl(vhosts)
return vhosts
def _update_vhosts_addrs_ssl(self, vhosts):
@@ -295,8 +291,18 @@ class NginxParser(object):
of the server block instead of the bottom
"""
self._modify_server_directives(vhost,
functools.partial(_add_directives, directives, insert_at_top))
for directive in directives:
if not _is_whitespace_or_empty(directive):
vhost.raw_obj.add_directive(directive, insert_at_top)
self._sync_old_structs(vhost)
def _sync_old_structs(self, vhost):
vhost.raw_obj._update_vhost()
old = self.parsed[vhost.filep]
for index in vhost.path[:-1]:
old = old[index]
old[vhost.path[-1]] = vhost.raw_obj.dump_unspaced_list()
def update_or_add_server_directives(self, vhost, directives, insert_at_top=False):
"""Add or replace directives in the server block identified by vhost.
@@ -317,8 +323,10 @@ class NginxParser(object):
of the server block instead of the bottom
"""
self._modify_server_directives(vhost,
functools.partial(_update_or_add_directives, directives, insert_at_top))
for directive in directives:
if not _is_whitespace_or_empty(directive):
vhost.raw_obj.update_or_add_directive(directive, insert_at_top)
self._sync_old_structs(vhost)
def remove_server_directives(self, vhost, directive_name, match_func=None):
"""Remove all directives of type directive_name.
@@ -339,6 +347,8 @@ class NginxParser(object):
vhost.ssl = parsed_server['ssl']
vhost.names = parsed_server['names']
vhost.raw = new_server
if vhost.raw_obj:
vhost.raw_obj.contents.parse(vhost.raw.spaced)
def _modify_server_directives(self, vhost, block_func):
filename = vhost.filep
@@ -374,8 +384,10 @@ class NginxParser(object):
new_vhost = copy.deepcopy(vhost_template)
enclosing_block = self.parsed[vhost_template.filep]
parsing_obj = self.parsed_root.context.parsed_files[vhost_template.filep]
for index in vhost_template.path[:-1]:
enclosing_block = enclosing_block[index]
parsing_obj = parsing_obj._data[index]
raw_in_parsed = copy.deepcopy(enclosing_block[vhost_template.path[-1]])
if only_directives is not None:
@@ -383,11 +395,13 @@ class NginxParser(object):
for directive in raw_in_parsed[1]:
if directive and directive[0] in only_directives:
new_directives.append(directive)
new_directives.append("\n")
raw_in_parsed[1] = new_directives
self._update_vhost_based_on_new_directives(new_vhost, new_directives)
enclosing_block.append(raw_in_parsed)
if "\n" not in enclosing_block[-1][0][0]:
enclosing_block[-1][0].insert(0, "\n")
new_vhost.path[-1] = len(enclosing_block) - 1
if remove_singleton_listen_params:
for addr in new_vhost.addrs:
@@ -406,6 +420,11 @@ class NginxParser(object):
keys = [x.split('=')[0] for x in directive]
if param in keys:
del directive[keys.index(param)]
og_vhost_path = vhost_template.raw_obj.get_path()[-1]
parsing_obj.parse(enclosing_block.spaced)
new_vhost = parsing_obj._data[new_vhost.path[-1]].vhost
vhost_template.raw_obj = parsing_obj._data[og_vhost_path]
vhost_template.raw_obj.vhost = vhost_template
return new_vhost
@@ -552,18 +571,11 @@ def _is_ssl_on_directive(entry):
len(entry) == 2 and entry[0] == 'ssl' and
entry[1] == 'on')
def _add_directives(directives, insert_at_top, block):
"""Adds directives to a config block."""
for directive in directives:
_add_directive(block, directive, insert_at_top)
if block and '\n' not in block[-1]: # could be " \n " or ["\n"] !
block.append(nginxparser.UnspacedList('\n'))
def _update_or_add_directives(directives, insert_at_top, block):
"""Adds or replaces directives in a config block."""
for directive in directives:
_update_or_add_directive(block, directive, insert_at_top)
if block and '\n' not in block[-1]: # could be " \n " or ["\n"] !
if block and '\n' not in block.spaced[-1]: # could be " \n " or ["\n"] !
block.append(nginxparser.UnspacedList('\n'))
@@ -588,8 +600,6 @@ def comment_directive(block, location):
next_entry = next_entry[0]
block.insert(location + 1, COMMENT_BLOCK[:])
if next_entry is not None and "\n" not in next_entry:
block.insert(location + 2, '\n')
def _comment_out_directive(block, location, include_location):
"""Comment out the line at location, with a note of explanation."""
@@ -624,6 +634,7 @@ def _is_whitespace_or_comment(directive):
"""Is this directive either a whitespace or comment directive?"""
return len(directive) == 0 or directive[0] == '#'
# block = Statements
def _add_directive(block, directive, insert_at_top):
if not isinstance(directive, nginxparser.UnspacedList):
directive = nginxparser.UnspacedList(directive)
@@ -723,6 +734,13 @@ def _apply_global_addr_ssl(addr_to_ssl, parsed_server):
if addr.ssl:
parsed_server['ssl'] = True
def _is_whitespace_or_empty(directive):
if not directive:
return True
if len(directive) != 1:
return False
return len(directive[0]) == 0 or directive[0].isspace()
def _parse_server_raw(server):
"""Parses a list of server directives.

View File

@@ -7,6 +7,8 @@ import six
from certbot import errors
from certbot_nginx import nginxparser
from acme.magic_typing import List # pylint: disable=unused-import, no-name-in-module
logger = logging.getLogger(__name__)
@@ -18,25 +20,19 @@ class Parsable(object):
""" Abstract base class for "Parsable" objects whose underlying representation
is a tree of lists.
:param .Parsable parent: This object's parsed parent in the tree
:param .ParseContext context: This object's context
"""
__metaclass__ = abc.ABCMeta
def __init__(self, parent=None):
def __init__(self, context=None):
self._data = [] # type: List[object]
self._tabs = None
self.parent = parent
self.context = context
@classmethod
def parsing_hooks(cls):
"""Returns object types that this class should be able to `parse` recusrively.
The order of the objects indicates the order in which the parser should
try to parse each subitem.
:returns: A list of Parsable classes.
:rtype list:
"""
return (Block, Sentence, Statements)
def get_parent(self):
if self.context == None:
return None
return self.context.parent
@staticmethod
@abc.abstractmethod
@@ -76,29 +72,6 @@ class Parsable(object):
"""
raise NotImplementedError()
@abc.abstractmethod
def get_tabs(self):
""" Guess at the tabbing style of this parsed object, based on whitespace.
If this object is a leaf, it deducts the tabbing based on its own contents.
Other objects may guess by calling `get_tabs` recursively on child objects.
:returns: Guess at tabbing for this object. Should only return whitespace strings
that does not contain newlines.
:rtype str:
"""
raise NotImplementedError()
@abc.abstractmethod
def set_tabs(self, tabs=" "):
"""This tries to set and alter the tabbing of the current object to a desired
whitespace string. Primarily meant for objects that were constructed, so they
can conform to surrounding whitespace.
:param str tabs: A whitespace string (not containing newlines).
"""
raise NotImplementedError()
def dump(self, include_spaces=False):
""" Dumps back to pyparsing-like list tree. The opposite of `parse`.
@@ -113,16 +86,42 @@ class Parsable(object):
"""
return [elem.dump(include_spaces) for elem in self._data]
def dump_unspaced_list(self):
""" Dumps back to pyparsing-like list tree into an UnspacedList.
Use for compatibility with UnspacedList dependencies while migrating
to new parsing objects.
class Statements(Parsable):
""" A group or list of "Statements". A Statement is either a Block or a Sentence.
:returns: Pyparsing-like list tree.
:rtype :class:`.nginxparser.UnspacedList`:
"""
return nginxparser.UnspacedList(self.dump(True))
The underlying representation is simply a list of these Statement objects, with
def child_context(self, filename=None, cwd=None):
""" Spawn a child context. """
if self.context:
return self.context.child(self, filename=filename)
return ParseContext(parent=self, filename=filename, cwd=cwd)
def get_path(self):
""" TODO: document and test"""
if not self.context.parent or self.context.parent.context.filename != self.context.filename:
return None
parent_path = self.context.parent.get_path()
my_index = self.context.parent._data.index(self)
if parent_path:
return parent_path + [my_index]
return [my_index]
class Directives(Parsable):
""" A group or list of Directives.
The underlying representation is simply a list of other parsed objects, with
an extra `_trailing_whitespace` string to keep track of the whitespace that does not
precede any more statements.
"""
def __init__(self, parent=None):
super(Statements, self).__init__(parent)
def __init__(self, context=None):
super(Directives, self).__init__(context)
self._trailing_whitespace = None
# ======== Begin overridden functions
@@ -131,43 +130,32 @@ class Statements(Parsable):
def should_parse(lists):
return isinstance(lists, list)
def set_tabs(self, tabs=" "):
""" Sets the tabbing for this set of statements. Does this by calling `set_tabs`
on each of the child statements.
Then, if a parent is present, sets trailing whitespace to parent tabbing. This
is so that the trailing } of any Block that contains Statements lines up
with parent tabbing.
"""
for statement in self._data:
statement.set_tabs(tabs)
if self.parent is not None:
self._trailing_whitespace = "\n" + self.parent.get_tabs()
def parse(self, raw_list, add_spaces=False):
""" Parses a list of statements.
Expects all elements in `raw_list` to be parseable by `type(self).parsing_hooks`,
with an optional whitespace string at the last index of `raw_list`.
"""
if isinstance(raw_list, nginxparser.UnspacedList):
raw_list = raw_list.spaced
if not isinstance(raw_list, list):
raise errors.MisconfigurationError("Statements parsing expects a list!")
# If there's a trailing whitespace in the list of statements, keep track of it.
if raw_list and isinstance(raw_list[-1], six.string_types) and raw_list[-1].isspace():
self._trailing_whitespace = raw_list[-1]
raw_list = raw_list[:-1]
self._data = [parse_raw(elem, self, add_spaces) for elem in raw_list]
def get_tabs(self):
""" Takes a guess at the tabbing of all contained Statements by retrieving the
tabbing of the first Statement."""
if self._data:
return self._data[0].get_tabs()
return ""
raise errors.MisconfigurationError("Directives parsing expects a list!")
# If there's a trailing whitespace in the list of statements, keep track of them
if raw_list:
i = -1
while len(raw_list) >= -i and isinstance(raw_list[i], six.string_types) and raw_list[i].isspace():
i -= 1
self._trailing_whitespace = "".join(raw_list[i+1:])
raw_list = raw_list[:i+1]
# Create parsing objects first, then parse. Then references to parent
# data exist while we parse the child objects.
self._data = [_choose_parser(self.child_context(), elem) for elem in raw_list]
for i, elem in enumerate(raw_list):
self._data[i].parse(elem, add_spaces)
def dump(self, include_spaces=False):
""" Dumps this object by first dumping each statement, then appending its
trailing whitespace (if `include_spaces` is set) """
data = super(Statements, self).dump(include_spaces)
data = super(Directives, self).dump(include_spaces)
if include_spaces and self._trailing_whitespace is not None:
return data + [self._trailing_whitespace]
return data
@@ -180,6 +168,37 @@ class Statements(Parsable):
# ======== End overridden functions
def update_directive(self, statement, index):
""" upd8
"""
self._data[index] = statement
if index + 1 >= len(self._data) or not _is_certbot_comment(self._data[index+1]):
self._data.insert(index+1, _certbot_comment(self.context))
def find_directive(self, match_func):
for i, elem in enumerate(self._data):
if isinstance(elem, Sentence) and match_func(elem):
return i
return -1
def add_directive(self, statement, insert_at_top=False):
""" Takes in a parse obj
"""
index = 0
if insert_at_top:
self._data.insert(0, statement)
else:
index = len(self._data)
self._data.append(statement)
if not _is_comment(statement):
self._data.insert(index+1, _certbot_comment(self.context))
def get_type(self, match_type):
""" TODO
"""
return self.iterate(expanded=True,
match=lambda elem: isinstance(elem, match_type))
def _space_list(list_):
""" Inserts whitespace between adjacent non-whitespace tokens. """
@@ -220,31 +239,15 @@ class Sentence(Parsable):
def iterate(self, expanded=False, match=None):
""" Simply yields itself. """
if match is None or match(self):
if (match is None) or match(self):
yield self
def set_tabs(self, tabs=" "):
""" Sets the tabbing on this sentence. Inserts a newline and `tabs` at the
beginning of `self._data`. """
if self._data[0].isspace():
return
self._data.insert(0, "\n" + tabs)
def dump(self, include_spaces=False):
""" Dumps this sentence. If include_spaces is set, includes whitespace tokens."""
if not include_spaces:
return self.words
return self._data
def get_tabs(self):
""" Guesses at the tabbing of this sentence. If the first element is whitespace,
returns the whitespace after the rightmost newline in the string. """
first = self._data[0]
if not first.isspace():
return ""
rindex = first.rfind("\n")
return first[rindex+1:]
# ======== End overridden functions
@property
@@ -252,6 +255,9 @@ class Sentence(Parsable):
""" Iterates over words, but without spaces. Like Unspaced List. """
return [word.strip("\"\'") for word in self._data if not word.isspace()]
def __len__(self):
return len(self.words)
def __getitem__(self, index):
return self.words[index]
@@ -270,8 +276,8 @@ class Block(Parsable):
names = ["block", " ", "name", " "]
contents = [["\n ", "content", " ", "1"], ["\n ", "content", " ", "2"], "\n"]
"""
def __init__(self, parent=None):
super(Block, self).__init__(parent)
def __init__(self, context=None):
super(Block, self).__init__(context)
self.names = None # type: Sentence
self.contents = None # type: Block
@@ -279,7 +285,7 @@ class Block(Parsable):
def should_parse(lists):
""" Returns True if `lists` can be parseable as a `Block`-- that is,
it's got a length of 2, the first element is a `Sentence` and the second can be
a `Statements`.
a `Directives`.
:param list lists: The raw unparsed list to check.
@@ -287,12 +293,6 @@ class Block(Parsable):
return isinstance(lists, list) and len(lists) == 2 and \
Sentence.should_parse(lists[0]) and isinstance(lists[1], list)
def set_tabs(self, tabs=" "):
""" Sets tabs by setting equivalent tabbing on names, then adding tabbing
to contents."""
self.names.set_tabs(tabs)
self.contents.set_tabs(tabs + " ")
def iterate(self, expanded=False, match=None):
""" Iterator over self, and if expanded is set, over its contents. """
if match is None or match(self):
@@ -314,18 +314,14 @@ class Block(Parsable):
if not Block.should_parse(raw_list):
raise errors.MisconfigurationError("Block parsing expects a list of length 2. "
"First element should be a list of string types (the bloc names), "
"and second should be another list of statements (the bloc content).")
self.names = Sentence(self)
"and second should be another list of directives (the bloc content).")
self.names = Sentence(self.child_context())
self.contents = Directives(self.child_context())
self._data = [self.names, self.contents]
if add_spaces:
raw_list[0].append(" ")
self.names.parse(raw_list[0], add_spaces)
self.contents = Statements(self)
self.contents.parse(raw_list[1], add_spaces)
self._data = [self.names, self.contents]
def get_tabs(self):
""" Guesses tabbing by retrieving tabbing guess of self.names. """
return self.names.get_tabs()
def _is_comment(parsed_obj):
""" Checks whether parsed_obj is a comment.
@@ -356,34 +352,34 @@ def _is_certbot_comment(parsed_obj):
return False
return True
def _certbot_comment(parent, preceding_spaces=4):
def _certbot_comment(context, preceding_spaces=4):
""" A "Managed by Certbot" comment.
:param int preceding_spaces: Number of spaces between the end of the previous
statement and the comment.
:returns: Sentence containing the comment.
:rtype: .Sentence
"""
result = Sentence(parent)
result = Sentence(context)
result.parse([" " * preceding_spaces] + COMMENT_BLOCK)
return result
def _choose_parser(parent, list_):
""" Choose a parser from type(parent).parsing_hooks, depending on whichever hook
def _choose_parser(context, list_):
""" Choose a parser from type(context).parsing_hooks, depending on whichever hook
returns True first. """
hooks = Parsable.parsing_hooks()
if parent:
hooks = type(parent).parsing_hooks()
hooks = ParseContext.parsing_hooks()
if context:
hooks = type(context).parsing_hooks()
for type_ in hooks:
if type_.should_parse(list_):
return type_(parent)
return type_(context)
raise errors.MisconfigurationError(
"None of the parsing hooks succeeded, so we don't know how to parse this set of lists.")
def parse_raw(lists_, parent=None, add_spaces=False):
def parse_raw(lists_, context=None, add_spaces=False):
""" Primary parsing factory function.
:param list lists_: raw lists from pyparsing to parse.
:param .Parent parent: The parent containing this object.
:param .ParseContext context: The context of this object.
:param bool add_spaces: Whether to pass add_spaces to the parser.
:returns .Parsable: The parsed object.
@@ -391,6 +387,38 @@ def parse_raw(lists_, parent=None, add_spaces=False):
:raises errors.MisconfigurationError: If no parsing hook passes, and we can't
determine which type to parse the raw lists into.
"""
parser = _choose_parser(parent, lists_)
if context is None:
context = ParseContext()
parser = _choose_parser(context, lists_)
parser.parse(lists_, add_spaces)
return parser
class ParseContext(object):
""" Context information held by parsed objects.
:param .Parsable parent: The parent object containing the associated object.
:param str filename: relative file path that the associated object was parsed from
:param str cwd: current working directory/root of the parsing files
"""
__metaclass__ = abc.ABCMeta
def __init__(self, parent=None, filename=None, cwd=None):
self.parent = parent
self.filename = filename
self.cwd = cwd
def child(self, parent, filename=None):
""" Returns Context with all fields inherited, except parent points to this object.
"""
return ParseContext(parent, filename if filename else self.filename, self.cwd)
@staticmethod
def parsing_hooks():
"""Returns object types that this class should be able to `parse` recusrively.
The order of the objects indicates the order in which the parser should
try to parse each subitem.
:returns: A list of Parsable classes.
:rtype list:
"""
return (Block, Sentence, Directives)

View File

@@ -177,6 +177,7 @@ class NginxConfiguratorTest(util.NginxTest):
if name == "ipv6.com":
self.assertTrue(vhost.ipv6_enabled())
# Make sure that we have SSL enabled also for IPv6 addr
print vhost
self.assertTrue(
any([True for x in vhost.addrs if x.ssl and x.ipv6]))
@@ -494,23 +495,23 @@ class NginxConfiguratorTest(util.NginxTest):
self.assertEqual(
[[['server'], [
['server_name', '.example.com'],
['server_name', 'example.*'], [],
['server_name', 'example.*'],
['listen', '5001', 'ssl'], ['#', ' managed by Certbot'],
['ssl_certificate', 'example/fullchain.pem'], ['#', ' managed by Certbot'],
['ssl_certificate_key', 'example/key.pem'], ['#', ' managed by Certbot'],
['include', self.config.mod_ssl_conf], ['#', ' managed by Certbot'],
['ssl_dhparam', self.config.ssl_dhparams], ['#', ' managed by Certbot'],
[], []]],
]],
[['server'], [
[['if', '($host', '=', 'www.example.com)'], [
['return', '301', 'https://$host$request_uri']]],
['#', ' managed by Certbot'], [],
['#', ' managed by Certbot'],
['listen', '69.50.225.155:9000'],
['listen', '127.0.0.1'],
['server_name', '.example.com'],
['server_name', 'example.*'],
['return', '404'], ['#', ' managed by Certbot'], [], [], []]]],
generated_conf)
['return', '404'], ['#', ' managed by Certbot'], ]]][0],
generated_conf[0])
def test_split_for_headers(self):
example_conf = self.config.parser.abs_path('sites-enabled/example.com')
@@ -525,22 +526,21 @@ class NginxConfiguratorTest(util.NginxTest):
self.assertEqual(
[[['server'], [
['server_name', '.example.com'],
['server_name', 'example.*'], [],
['server_name', 'example.*'],
['listen', '5001', 'ssl'], ['#', ' managed by Certbot'],
['ssl_certificate', 'example/fullchain.pem'], ['#', ' managed by Certbot'],
['ssl_certificate_key', 'example/key.pem'], ['#', ' managed by Certbot'],
['include', self.config.mod_ssl_conf], ['#', ' managed by Certbot'],
['ssl_dhparam', self.config.ssl_dhparams], ['#', ' managed by Certbot'],
[], [],
['add_header', 'Strict-Transport-Security', '"max-age=31536000"', 'always'],
['#', ' managed by Certbot'],
[], []]],
]],
[['server'], [
['listen', '69.50.225.155:9000'],
['listen', '127.0.0.1'],
['server_name', '.example.com'],
['server_name', 'example.*'],
[], [], []]]],
]]],
generated_conf)
def test_http_header_hsts(self):
@@ -633,6 +633,7 @@ class NginxConfiguratorTest(util.NginxTest):
default_conf = self.config.parser.abs_path('sites-enabled/default')
foo_conf = self.config.parser.abs_path('foo.conf')
del self.config.parser.parsed[foo_conf][2][1][0][1][0] # remove default_server
self.config.parser.get_vhost(foo_conf).raw_obj.parse(self.config.parser.parsed[foo_conf][2][1][0])
self.config.version = (1, 3, 1)
self.config.deploy_cert(
@@ -687,6 +688,7 @@ class NginxConfiguratorTest(util.NginxTest):
foo_conf = self.config.parser.abs_path('foo.conf')
del self.config.parser.parsed[default_conf][0][1][0]
del self.config.parser.parsed[default_conf][0][1][0]
self.config.parser.get_vhost(default_conf).raw_obj.parse(self.config.parser.parsed[default_conf][0])
self.config.version = (1, 3, 1)
self.config.deploy_cert(
@@ -721,6 +723,8 @@ class NginxConfiguratorTest(util.NginxTest):
del self.config.parser.parsed[default_conf][0][1][0]
del self.config.parser.parsed[default_conf][0][1][0]
del self.config.parser.parsed[foo_conf][2][1][0][1][0]
self.config.parser.get_vhost(foo_conf).raw_obj.parse(self.config.parser.parsed[foo_conf][2][1][0])
self.config.parser.get_vhost(default_conf).raw_obj.parse(self.config.parser.parsed[default_conf][0])
self.config.version = (1, 3, 1)
self.assertRaises(errors.MisconfigurationError, self.config.deploy_cert,
@@ -736,6 +740,7 @@ class NginxConfiguratorTest(util.NginxTest):
def test_deploy_no_match_multiple_defaults_ok(self):
foo_conf = self.config.parser.abs_path('foo.conf')
self.config.parser.parsed[foo_conf][2][1][0][1][0][1] = '*:5001'
self.config.parser.get_vhost(foo_conf).raw_obj.parse(self.config.parser.parsed[foo_conf][2][1][0])
self.config.version = (1, 3, 1)
self.config.deploy_cert("www.nomatch.com", "example/cert.pem", "example/key.pem",
"example/chain.pem", "example/fullchain.pem")
@@ -744,6 +749,7 @@ class NginxConfiguratorTest(util.NginxTest):
default_conf = self.config.parser.abs_path('sites-enabled/default')
foo_conf = self.config.parser.abs_path('foo.conf')
del self.config.parser.parsed[foo_conf][2][1][0][1][0] # remove default_server
self.config.parser.get_vhost(foo_conf).raw_obj.parse(self.config.parser.parsed[foo_conf][2][1][0])
self.config.version = (1, 3, 1)
self.config.deploy_cert(

View File

@@ -53,7 +53,7 @@ class ParsingHooksTest(unittest.TestCase):
self.assertFalse(Block.should_parse([['block_name'], 'lol']))
self.assertTrue(Block.should_parse([['block_name'], ['hi', []]]))
self.assertTrue(Block.should_parse([['hello'], []]))
self.assertTrue(Block.should_parse([['block_name'], [['many'], ['statements'], 'here']]))
self.assertTrue(Block.should_parse([['block_name'], [['many'], ['directives'], 'here']]))
self.assertTrue(Block.should_parse([['if', ' ', '(whatever)'], ['hi']]))
def test_parse_raw(self):
@@ -71,7 +71,7 @@ class ParsingHooksTest(unittest.TestCase):
fake_parser1.not_called()
fake_parser2.called_once()
@mock.patch("certbot_nginx.parser_obj.Parsable.parsing_hooks")
@mock.patch("certbot_nginx.parser_obj.ParseContext.parsing_hooks")
def test_parse_raw_no_match(self, parsing_hooks):
from certbot import errors
fake_parser1 = mock.Mock()
@@ -119,22 +119,6 @@ class SentenceTest(unittest.TestCase):
for i, sentence in enumerate(self.sentence.iterate()):
self.assertEqual(sentence.dump(), expected[i])
def test_set_tabs(self):
self.sentence.parse(['tabs', 'pls'], add_spaces=True)
self.sentence.set_tabs()
self.assertEqual(self.sentence.dump(True)[0], '\n ')
self.sentence.parse(['tabs', 'pls'], add_spaces=True)
def test_get_tabs(self):
self.sentence.parse(['no', 'tabs'])
self.assertEqual(self.sentence.get_tabs(), '')
self.sentence.parse(['\n \n ', 'tabs'])
self.assertEqual(self.sentence.get_tabs(), ' ')
self.sentence.parse(['\n\t ', 'tabs'])
self.assertEqual(self.sentence.get_tabs(), '\t ')
self.sentence.parse(['\n\t \n', 'tabs'])
self.assertEqual(self.sentence.get_tabs(), '')
class BlockTest(unittest.TestCase):
def setUp(self):
from certbot_nginx.parser_obj import Block
@@ -179,21 +163,10 @@ class BlockTest(unittest.TestCase):
self.assertRaises(errors.MisconfigurationError, self.bloc.parse, ['lol'])
self.assertRaises(errors.MisconfigurationError, self.bloc.parse, ['fake', 'news'])
def test_set_tabs(self):
self.bloc.set_tabs()
self.assertEqual(self.bloc.names.dump(True)[0], '\n ')
for elem in self.bloc.contents.dump(True)[:-1]:
self.assertEqual(elem[0], '\n ')
self.assertEqual(self.bloc.contents.dump(True)[-1][0], '\n')
def test_get_tabs(self):
self.bloc.parse([[' \n \t', 'lol'], []])
self.assertEqual(self.bloc.get_tabs(), ' \t')
class StatementsTest(unittest.TestCase):
class DirectivesTest(unittest.TestCase):
def setUp(self):
from certbot_nginx.parser_obj import Statements
self.statements = Statements(None)
from certbot_nginx.parser_obj import Directives
self.directives = Directives(None)
self.raw = [
['sentence', 'one'],
['sentence', 'two'],
@@ -206,47 +179,24 @@ class StatementsTest(unittest.TestCase):
'\n\n'
]
def test_set_tabs(self):
self.statements.parse(self.raw)
self.statements.set_tabs()
for statement in self.statements.iterate():
self.assertEqual(statement.dump(True)[0], '\n ')
def test_set_tabs_with_parent(self):
# Trailing whitespace should inherit from parent tabbing.
self.statements.parse(self.raw)
self.statements.parent = mock.Mock()
self.statements.parent.get_tabs.return_value = '\t\t'
self.statements.set_tabs()
for statement in self.statements.iterate():
self.assertEqual(statement.dump(True)[0], '\n ')
self.assertEqual(self.statements.dump(True)[-1], '\n\t\t')
def test_get_tabs(self):
self.raw[0].insert(0, '\n \n \t')
self.statements.parse(self.raw)
self.assertEqual(self.statements.get_tabs(), ' \t')
self.statements.parse([])
self.assertEqual(self.statements.get_tabs(), '')
def test_parse_with_added_spaces(self):
self.statements.parse(self.raw, add_spaces=True)
self.assertEqual(self.statements.dump(True)[0], ['sentence', ' ', 'one'])
self.directives.parse(self.raw, add_spaces=True)
self.assertEqual(self.directives.dump(True)[0], ['sentence', ' ', 'one'])
def test_parse_bad_list_raises_error(self):
from certbot import errors
self.assertRaises(errors.MisconfigurationError, self.statements.parse, 'lol not a list')
self.assertRaises(errors.MisconfigurationError, self.directives.parse, 'lol not a list')
def test_parse_hides_trailing_whitespace(self):
self.statements.parse(self.raw + ['\n\n '])
self.assertTrue(isinstance(self.statements.dump()[-1], list))
self.assertTrue(self.statements.dump(True)[-1].isspace())
self.assertEqual(self.statements.dump(True)[-1], '\n\n ')
self.directives.parse(self.raw + ['\n\n '])
self.assertTrue(isinstance(self.directives.dump()[-1], list))
self.assertTrue(self.directives.dump(True)[-1].isspace())
self.assertEqual(self.directives.dump(True)[-1], '\n\n ')
def test_iterate(self):
self.statements.parse(self.raw)
self.directives.parse(self.raw)
expected = [['sentence', 'one'], ['sentence', 'two']]
for i, elem in enumerate(self.statements.iterate(match=lambda x: 'sentence' in x)):
for i, elem in enumerate(self.directives.iterate(match=lambda x: 'sentence' in x)):
self.assertEqual(expected[i], elem.dump())
if __name__ == "__main__":

View File

@@ -365,7 +365,6 @@ class NginxParserTest(util.NginxTest): #pylint: disable=too-many-public-methods
self.assertEqual(block.spaced, [
["\n", "a", " ", "b", "\n"],
COMMENT_BLOCK,
"\n",
["c", " ", "d"],
COMMENT_BLOCK,
["\n", "e", " ", "f"]])