Compare commits

...

1 Commits

Author SHA1 Message Date
Joona Hoikkala
a5d1509434 ParserNode implementation of VirtualHost creation 2019-08-14 00:04:40 +03:00
4 changed files with 411 additions and 4 deletions

View File

@@ -0,0 +1,192 @@
""" Tests for ParserNode interface """
from certbot_apache import interfaces
from acme.magic_typing import Dict, Tuple # pylint: disable=unused-import, no-name-in-module
class AugeasCommentNode(interfaces.CommentNode):
""" Augeas implementation of CommentNode interface """
ancestor = None
comment = ""
dirty = False
_metadata = dict() # type: Dict[str, object]
def __init__(self, comment, ancestor=None):
self.comment = comment
self.ancestor = ancestor
def save(self, msg): # pragma: no cover
pass
# Apache specific functionality
def get_metadata(self, key):
""" Returns a metadata object
:param str key: Metadata object name to return
:returns: Requested metadata object
"""
try:
return self._metadata[key]
except KeyError:
return None
class AugeasDirectiveNode(interfaces.DirectiveNode):
""" Augeas implementation of DirectiveNode interface """
ancestor = None
parameters = tuple() # type: Tuple[str, ...]
dirty = False
enabled = True
name = ""
_metadata = dict() # type: Dict[str, object]
def __init__(self, name, parameters=tuple(), ancestor=None):
self.name = name
self.parameters = parameters
self.ancestor = ancestor
def save(self, msg): # pragma: no cover
pass
def set_parameters(self, parameters): # pragma: no cover
self.parameters = tuple("CERTBOT_PASS_ASSERT")
# Apache specific functionality
def get_filename(self):
"""Returns the filename where this directive exists on disk
:returns: File path to this node.
:rtype: str
"""
# Following is the real implementation when everything else is in place:
# return apache_util.get_file_path(
# self.parser.aug.get("/augeas/files%s/path" % apache_util.get_file_path(path)))
return "CERTBOT_PASS_ASSERT"
def get_metadata(self, key):
""" Returns a metadata object
:param str key: Metadata object name to return
:returns: Requested metadata object
"""
try:
return self._metadata[key]
except KeyError:
return None
def has_parameter(self, parameter, position=None):
"""Checks if this ParserNode object has a supplied parameter. This check
is case insensitive.
:param str parameter: Parameter value to look for
:param position: Optional explicit position of parameter to look for
:returns: True if parameter is found
:rtype: bool
"""
if position != None:
return parameter.lower() == self.parameters[position].lower()
for param in self.parameters:
if param.lower() == parameter.lower():
return True
return False
class AugeasBlockNode(interfaces.BlockNode):
""" Augeas implementation of BlockNode interface """
ancestor = None
parameters = tuple() # type: Tuple[str, ...]
children = tuple() # type: Tuple[interfaces.ParserNode, ...]
dirty = False
enabled = True
name = ""
_metadata = dict() # type: Dict[str, object]
def __init__(self, name, parameters=tuple(), ancestor=None):
self.name = name
self.parameters = parameters
self.ancestor = ancestor
def save(self, msg): # pragma: no cover
pass
def add_child_block(self, name, parameters=None, position=None): # pragma: no cover
new_block = AugeasBlockNode("CERTBOT_PASS_ASSERT", ancestor=self)
self.children += (new_block,)
return new_block
def add_child_directive(self, name, parameters=None, position=None): # pragma: no cover
new_dir = AugeasDirectiveNode("CERTBOT_PASS_ASSERT", ancestor=self)
self.children += (new_dir,)
return new_dir
def add_child_comment(self, comment="", position=None): # pragma: no cover
new_comment = AugeasCommentNode("CERTBOT_PASS_ASSERT", ancestor=self)
self.children += (new_comment,)
return new_comment
def find_blocks(self, name, exclude=True): # pragma: no cover
return [AugeasBlockNode("CERTBOT_PASS_ASSERT", ancestor=self)]
def find_directives(self, name, exclude=True): # pragma: no cover
return [AugeasDirectiveNode("CERTBOT_PASS_ASSERT", ancestor=self)]
def find_comments(self, comment, exact=False): # pragma: no cover
return [AugeasCommentNode("CERTBOT_PASS_ASSERT", ancestor=self)]
def delete_child(self, child): # pragma: no cover
pass
def set_parameters(self, parameters): # pragma: no cover
self.parameters = tuple("CERTBOT_PASS_ASSERT")
def unsaved_files(self): # pragma: no cover
return ["CERTBOT_PASS_ASSERT"]
# Apache specific functionality
def get_filename(self):
"""Returns the filename where this directive exists on disk
:returns: File path to this node.
:rtype: str
"""
# Following is the real implementation when everything else is in place:
# return apache_util.get_file_path(
# self.parser.aug.get("/augeas/files%s/path" %
# apache_util.get_file_path(self.get_metadata("augeas_path")))
return "CERTBOT_PASS_ASSERT"
def get_metadata(self, key):
""" Returns a metadata object
:param str key: Metadata object name to return
:returns: Requested metadata object
"""
try:
return self._metadata[key]
except KeyError:
return None
def has_parameter(self, parameter, position=None):
"""Checks if this ParserNode object has a supplied parameter. This check
is case insensitive.
:param str parameter: Parameter value to look for
:param position: Optional explicit position of parameter to look for
:returns: True if parameter is found
:rtype: bool
"""
if position != None:
return parameter.lower() == self.parameters[position].lower()
for param in self.parameters:
if param.lower() == parameter.lower():
return True
return False

View File

@@ -810,6 +810,28 @@ class ApacheConfigurator(common.Installer):
return (servername, serveraliases)
def _populate_vhost_names_v2(self, vhost):
"""Helper function that populates the VirtualHost names.
:param host: In progress vhost whose names will be added
:type host: :class:`~certbot_apache.obj.VirtualHost`
"""
servername_match = vhost.node.find_directives("ServerName",
exclude=False)
serveralias_match = vhost.node.find_directives("ServerAlias",
exclude=False)
if servername_match:
servername = servername_match[-1].parameters[-1]
if not vhost.modmacro:
for alias in serveralias_match:
for serveralias in alias.parameters:
vhost.aliases.add(serveralias)
vhost.name = servername
def _add_servernames(self, host):
"""Helper function for get_virtual_hosts().
@@ -871,6 +893,52 @@ class ApacheConfigurator(common.Installer):
self._add_servernames(vhost)
return vhost
def _create_vhost_v2(self, node):
"""Used by get_virtual_hosts to create vhost objects using ParserNode
interfaces.
:param interfaces.BlockNode node: The BlockNode object of VirtualHost block
:returns: newly created vhost
:rtype: :class:`~certbot_apache.obj.VirtualHost`
"""
addrs = set()
for param in node.parameters:
addrs.add(obj.Addr.fromstring(param))
is_ssl = False
sslengine = node.find_directives("SSLEngine")
if sslengine:
for directive in sslengine:
# TODO: apache-parser-v2
# This search should be made wiser. (using other identificators)
if directive.has_parameter("on", 0):
is_ssl = True
# "SSLEngine on" might be set outside of <VirtualHost>
# Treat vhosts with port 443 as ssl vhosts
for addr in addrs:
if addr.get_port() == "443":
is_ssl = True
filename = node.get_filename()
if filename is None:
return None
macro = False
if node.find_directives("Macro"):
macro = True
vhost_enabled = self.parser.parsed_in_original(filename)
vhost = obj.VirtualHost(filename, node.get_metadata("augeas_path"),
addrs, is_ssl, vhost_enabled, modmacro=macro,
node=node)
self._populate_vhost_names_v2(vhost)
return vhost
def get_virtual_hosts(self):
"""Returns list of virtual hosts found in the Apache configuration.

View File

@@ -124,7 +124,7 @@ class VirtualHost(object): # pylint: disable=too-few-public-methods
strip_name = re.compile(r"^(?:.+://)?([^ :$]*)")
def __init__(self, filep, path, addrs, ssl, enabled, name=None,
aliases=None, modmacro=False, ancestor=None):
aliases=None, modmacro=False, ancestor=None, node=None):
# pylint: disable=too-many-arguments
"""Initialize a VH."""
@@ -137,6 +137,7 @@ class VirtualHost(object): # pylint: disable=too-few-public-methods
self.enabled = enabled
self.modmacro = modmacro
self.ancestor = ancestor
self.node = node
def get_names(self):
"""Return a set of all names."""

View File

@@ -2,10 +2,15 @@
import unittest
from acme.magic_typing import Optional, Tuple # pylint: disable=unused-import, no-name-in-module
import mock
from acme.magic_typing import Dict, Tuple # pylint: disable=unused-import, no-name-in-module
from certbot_apache import augeasparser
from certbot_apache import interfaces
from certbot_apache.tests import util
class DummyCommentNode(interfaces.CommentNode):
@@ -73,14 +78,155 @@ class DummyBlockNode(interfaces.BlockNode):
pass
class ParserNodeTest(unittest.TestCase):
"""Dummy placeholder test case for ParserNode interfaces"""
class ParserNodeTest(util.ApacheTest):
"""Test cases for ParserNode interface"""
def __init__(self, *args, **kwargs):
super(ParserNodeTest, self).__init__(*args, **kwargs)
self.mock_nodes = dict() # type: Dict[str, interfaces.ParserNode]
def setUp(self): # pylint: disable=arguments-differ
super(ParserNodeTest, self).setUp()
self.config = util.get_apache_configurator(
self.config_path, self.vhost_path, self.config_dir, self.work_dir)
self.vh_truth = util.get_vh_truth(
self.temp_dir, "debian_apache_2_4/multiple_vhosts")
def test_dummy(self):
dummyblock = DummyBlockNode()
dummydirective = DummyDirectiveNode()
dummycomment = DummyCommentNode()
def _create_mock_vhost_nodes(self, servername, serveraliases, addrs):
"""Create a mock VirtualHost nodes"""
nodes = {
"VirtualHost": augeasparser.AugeasBlockNode("VirtualHost", tuple(addrs)),
"ServerName": augeasparser.AugeasDirectiveNode("ServerName",
(servername,)),
"ServerAlias": augeasparser.AugeasDirectiveNode("ServerAlias",
tuple(serveraliases)),
"Macro": augeasparser.AugeasDirectiveNode("Macro", ("variable", "value",)),
"SSLEngine": augeasparser.AugeasDirectiveNode("SSLEngine", ("on",))
}
return nodes
def mock_find_directives(self, name, exclude=True): # pylint: disable=unused-argument
"""
Mocks BlockNode.find_directives() and returns values defined in class
variable self.mock_nodes, set by the test case
"""
try:
return self.mock_nodes[name]
except KeyError:
return []
def test_create_vhost_v2_nonssl(self):
nodes = self._create_mock_vhost_nodes("example.com",
["a1.example.com", "a2.example.com"],
["*:80"])
nodes["VirtualHost"].find_directives = self.mock_find_directives
self.mock_nodes = {"ServerName": [nodes["ServerName"]],
"ServerAlias": [nodes["ServerAlias"]]}
vhost = self.config._create_vhost_v2(nodes["VirtualHost"]) # pylint: disable=protected-access
self.assertEqual(vhost.name, "example.com")
self.assertTrue("a1.example.com" in vhost.aliases)
self.assertTrue("a2.example.com" in vhost.aliases)
self.assertEqual(len(vhost.aliases), 2)
self.assertEqual(len(vhost.addrs), 1)
self.assertFalse(vhost.ssl)
self.assertFalse(vhost.modmacro)
def test_create_vhost_v2_macro(self):
nodes = self._create_mock_vhost_nodes("example.com",
["a1.example.com", "a2.example.com"],
["*:80"])
nodes["VirtualHost"].find_directives = self.mock_find_directives
self.mock_nodes = {"ServerName": [nodes["ServerName"]],
"ServerAlias": [nodes["ServerAlias"]],
"Macro": [nodes["Macro"]]}
vhost = self.config._create_vhost_v2(nodes["VirtualHost"]) # pylint: disable=protected-access
self.assertEqual(vhost.name, None)
self.assertEqual(vhost.aliases, set())
self.assertFalse(vhost.ssl)
self.assertTrue(vhost.modmacro)
def test_create_vhost_v2_ssl_port(self):
nodes = self._create_mock_vhost_nodes("example.com",
["a1.example.com", "a2.example.com"],
["*:443"])
nodes["VirtualHost"].find_directives = self.mock_find_directives
self.mock_nodes = {"ServerName": [nodes["ServerName"]],
"ServerAlias": [nodes["ServerAlias"]]}
vhost = self.config._create_vhost_v2(nodes["VirtualHost"]) # pylint: disable=protected-access
self.assertTrue(vhost.ssl)
self.assertFalse(vhost.modmacro)
def test_create_vhost_v2_sslengine(self):
nodes = self._create_mock_vhost_nodes("example.com",
["a1.example.com", "a2.example.com"],
["*:80"])
nodes["VirtualHost"].find_directives = self.mock_find_directives
self.mock_nodes = {"ServerName": [nodes["ServerName"]],
"ServerAlias": [nodes["ServerAlias"]],
"SSLEngine": [nodes["SSLEngine"]]}
vhost = self.config._create_vhost_v2(nodes["VirtualHost"]) # pylint: disable=protected-access
self.assertTrue(vhost.ssl)
self.assertFalse(vhost.modmacro)
def test_create_vhost_v2_no_filename(self):
nodes = self._create_mock_vhost_nodes("example.com",
["a1.example.com", "a2.example.com"],
["*:80"])
nodes["VirtualHost"].find_directives = self.mock_find_directives
self.mock_nodes = {"ServerName": [nodes["ServerName"]],
"ServerAlias": [nodes["ServerAlias"]],
"SSLEngine": [nodes["SSLEngine"]]}
filename = "certbot_apache.augeasparser.AugeasBlockNode.get_filename"
with mock.patch(filename) as mock_filename:
mock_filename.return_value = None
vhost = self.config._create_vhost_v2(nodes["VirtualHost"]) # pylint: disable=protected-access
self.assertEqual(vhost, None)
def test_comment_node_creation(self):
comment = augeasparser.AugeasCommentNode("This is a comment")
comment._metadata["augeas_path"] = "/whatever" # pylint: disable=protected-access
self.assertEqual(comment.get_metadata("augeas_path"), "/whatever")
self.assertEqual(comment.get_metadata("something_else"), None)
self.assertEqual(comment.comment, "This is a comment")
def test_directive_node_creation(self):
directive = augeasparser.AugeasDirectiveNode("DIRNAME", ("p1", "p2",))
directive._metadata["augeas_path"] = "/whatever" # pylint: disable=protected-access
self.assertEqual(directive.get_metadata("augeas_path"), "/whatever")
self.assertEqual(directive.get_metadata("something_else"), None)
self.assertEqual(directive.name, "DIRNAME")
self.assertEqual(directive.parameters, ("p1", "p2",))
self.assertTrue(directive.has_parameter("P1", 0))
self.assertFalse(directive.has_parameter("P2", 0))
self.assertFalse(directive.has_parameter("P3"))
self.assertTrue(directive.has_parameter("P2"))
self.assertEqual(directive.get_filename(), "CERTBOT_PASS_ASSERT")
def test_block_node_creation(self):
block = augeasparser.AugeasBlockNode("BLOCKNAME", ("first", "SECOND",))
block._metadata["augeas_path"] = "/whatever" # pylint: disable=protected-access
self.assertEqual(block.get_metadata("augeas_path"), "/whatever")
self.assertEqual(block.get_metadata("something_else"), None)
self.assertEqual(block.name, "BLOCKNAME")
self.assertEqual(block.parameters, ("first", "SECOND",))
self.assertFalse(block.has_parameter("second", 0))
self.assertFalse(block.has_parameter("SECOND", 0))
self.assertFalse(block.has_parameter("third"))
self.assertTrue(block.has_parameter("FIRST"))
self.assertEqual(block.get_filename(), "CERTBOT_PASS_ASSERT")
if __name__ == "__main__":
unittest.main() # pragma: no cover