#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the Wapiti project (http://wapiti.sourceforge.net)
# Copyright (C) 2008-2013 Nicolas Surribas
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
import sys
import getopt
import os
import urlparse
import time

# Not referenced anywhere else in the visible part of this file.
BASE_DIR = None
WAPITI_VERSION = "Wapiti 2.3.0"

# Locate the configuration directory (CONF_DIR) and import the language
# support, for both the source layout and the frozen (py2exe) layout.
if not hasattr(sys, "frozen"):
    # Regular execution from a script file.
    _script_dir = os.path.dirname(os.path.abspath(__file__))
    parent_dir = os.path.abspath(os.path.join(_script_dir, os.pardir))
    if os.path.exists(os.path.join(parent_dir, "wapitiCore")):
        # wapitiCore lives in the parent directory: make it importable.
        sys.path.append(parent_dir)
    from wapitiCore.language.language import Language
    # Must come after the import: the package location is read from sys.modules.
    CONF_DIR = os.path.dirname(sys.modules['wapitiCore'].__file__)
else:
    # For py2exe: use the "data" folder located beside the executable.
    _exe_path = unicode(sys.executable, sys.getfilesystemencoding())
    CONF_DIR = os.path.join(os.path.dirname(_exe_path), "data")
    from wapitiCore.language.language import Language


# Configure the language support before importing the rest of wapitiCore.
# NOTE(review): presumably configure() installs the `_` translation function
# used throughout this file (and possibly by the modules imported below at
# import time) -- confirm before moving these imports to the top of the file.
lan = Language()
lan.configure()
from wapitiCore.net import HTTP, lswww
from wapitiCore.file.reportgeneratorsxmlparser import ReportGeneratorsXMLParser
from wapitiCore.file.vulnerabilityxmlparser import VulnerabilityXMLParser
from wapitiCore.file.anomalyxmlparser import AnomalyXMLParser
from wapitiCore.net.crawlerpersister import CrawlerPersister


class InvalidOptionValue(Exception):
    """Raised when a command-line option receives a value it cannot accept.

    Attributes:
        opt_name: the option as typed on the command line (e.g. "-n").
        opt_value: the rejected value.
    """

    def __init__(self, opt_name, opt_value):
        # Forward the values to the base class so that `args` is populated:
        # this gives a meaningful repr() and lets the exception be rebuilt
        # from its args (copy / pickle).
        super(InvalidOptionValue, self).__init__(opt_name, opt_value)
        self.opt_name = opt_name
        self.opt_value = opt_value

    def __str__(self):
        """Return the translated, human-readable error message."""
        return _("Invalid argument for option {0} : {1}").format(self.opt_name, self.opt_value)


class Wapiti(object):
    """Front end tying together the HTTP engine, the crawler, the attack modules
    and the report generator according to the options set by the caller.

    Launch wapiti without arguments or with the "-h" option for more information.
    """

    target_url = None
    target_scope = "folder"
    # Class-level defaults only: __init__ gives each instance its own containers.
    urls = {}
    forms = []

    color = 0
    verbose = 0

    reportGeneratorType = "html"
    REPORT_DIR = "report"
    REPORT_FILE = "vulnerabilities"
    # Fall back on expanduser() so that the class can still be defined when
    # neither HOME nor USERPROFILE is set in the environment.
    HOME_DIR = os.getenv('HOME') or os.getenv('USERPROFILE') or os.path.expanduser("~")
    COPY_REPORT_DIR = os.path.join(HOME_DIR, ".wapiti", "generated_report")
    outputFile = ""

    options = ""

    http_engine = None
    myls = None
    reportGen = None

    # Class-level default only (see __init__).
    attacks = []

    def __init__(self, root_url):
        """Create the HTTP engine and the crawler for `root_url` and load the
        list of available report generators."""
        self.target_url = root_url
        # Per-instance containers: mutating the class-level lists/dict would
        # leak loaded modules and crawl results from one instance to another.
        self.urls = {}
        self.forms = []
        self.attacks = []

        server = urlparse.urlparse(root_url).netloc
        self.http_engine = HTTP.HTTP(server)
        self.myls = lswww.lswww(root_url, http_engine=self.http_engine)
        self.xmlRepGenParser = ReportGeneratorsXMLParser()
        self.xmlRepGenParser.parse(os.path.join(CONF_DIR, "config", "reports", "generators.xml"))

    def __initReport(self):
        """Instantiate the report generator matching `reportGeneratorType` and
        register the known vulnerability and anomaly types with it."""
        # If no generator key matches, reportGen keeps its previous value
        # (None by default) and the registration calls below will fail.
        wanted_key = self.reportGeneratorType.lower()
        for repGenInfo in self.xmlRepGenParser.getReportGenerators():
            if wanted_key == repGenInfo.getKey():
                self.reportGen = repGenInfo.createInstance()
                self.reportGen.setReportInfo(target=self.target_url,
                                             scope=self.target_scope,
                                             date_string=time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.gmtime()),
                                             version=WAPITI_VERSION)
                break

        vulnXMLParser = VulnerabilityXMLParser()
        vulnXMLParser.parse(os.path.join(CONF_DIR, "config", "vulnerabilities", "vulnerabilities.xml"))
        for vul in vulnXMLParser.getVulnerabilities():
            self.reportGen.addVulnerabilityType(_(vul.getName()),
                                                _(vul.getDescription()),
                                                _(vul.getSolution()),
                                                vul.getReferences())

        anomXMLParser = AnomalyXMLParser()
        anomXMLParser.parse(os.path.join(CONF_DIR, "config", "vulnerabilities", "anomalies.xml"))
        for anomaly in anomXMLParser.getAnomalies():
            # NOTE(review): unlike the name and the solution, the description is
            # not passed through _() -- confirm whether this is intentional.
            self.reportGen.addAnomalyType(_(anomaly.getName()),
                                          (anomaly.getDescription()),
                                          _(anomaly.getSolution()),
                                          anomaly.getReferences())

    def _setModuleState(self, module, method, enabled):
        """Enable or disable GET and/or POST attacks for one module or for all.

        module: name of an attack module, or "all" for every loaded module.
        method: "get", "post" or "" (both).
        enabled: value assigned to the doGET / doPOST flags.

        Return True when at least one loaded module was updated. A warning is
        printed when a named module (other than "all") cannot be found.
        """
        every_module = (module == "all")
        found = False
        for attack_module in self.attacks:
            if every_module or attack_module.name == module:
                found = True
                if method == "get" or method == "":
                    attack_module.doGET = enabled
                if method == "post" or method == "":
                    attack_module.doPOST = enabled
        if not found and not every_module:
            print(_("[!] Unable to find a module named {0}").format(module))
        return found

    def _missingRequirements(self, module):
        """Return the names listed in `module.require` that do not match any
        loaded module with GET or POST attacks enabled (in `require` order)."""
        active = set(other.name for other in self.attacks if other.doGET or other.doPOST)
        return [name for name in module.require if name not in active]

    def __initAttacks(self):
        """Prepare the report, load every attack module and apply the module
        selection stored by setModules()."""
        self.__initReport()

        from wapitiCore.attack import attack

        print(_("[*] Loading modules:"))
        print(u"\t {0}".format(u", ".join(attack.modules)))
        for mod_name in attack.modules:
            mod = __import__("wapitiCore.attack." + mod_name, fromlist=attack.modules)
            mod_instance = getattr(mod, mod_name)(self.http_engine, self.reportGen)
            if hasattr(mod_instance, "setTimeout"):
                mod_instance.setTimeout(self.http_engine.getTimeOut())
            self.attacks.append(mod_instance)

        # Sort once, after loading. The sort is stable, so modules sharing a
        # priority keep their loading order.
        self.attacks.sort(key=lambda loaded: loaded.PRIORITY)

        for attack_module in self.attacks:
            attack_module.setVerbose(self.verbose)
            if self.color == 1:
                attack_module.setColor()

        if self.options != "":
            for opt in self.options.split(","):
                # An option looks like [+|-]module[:method]
                if opt.find(":") > 0:
                    module, method = opt.split(":", 1)
                else:
                    module, method = opt, ""

                # A leading "-" deactivates; a leading "+" (or no sign) activates.
                enabled = not module.startswith("-")
                if module.startswith("-") or module.startswith("+"):
                    module = module[1:]
                self._setModuleState(module, method, enabled)

    def browse(self, crawlerFile):
        """Extract hyperlinks and forms from the webpages found on the website."""
        self.myls.go(crawlerFile)
        self.urls = self.myls.getLinks()
        self.forms = self.myls.getForms()

    def attack(self):
        """Launch the attacks based on the preferences set by the command line,
        then generate the report.

        Exits the process with status 1 when there is nothing to attack.
        """
        if self.urls == {} and self.forms == []:
            print(_("No links or forms found in this page !"))
            print(_("Make sure the url is correct."))
            sys.exit(1)

        self.__initAttacks()

        for module in self.attacks:
            if module.doGET is False and module.doPOST is False:
                continue
            print('')
            if module.require != []:
                # Compare by membership, not by list equality: the loaded
                # modules are ordered by priority, which need not match the
                # order in which the requirements were declared.
                missing = self._missingRequirements(module)
                if missing:
                    print(_("[!] Missing dependecies for module {0}:").format(module.name))
                    print(u"  {0}".format(",".join(missing)))
                    continue
                module.loadRequire([other for other in self.attacks if other.name in module.require])

            module.logG(_("[+] Launching module {0}"), module.name)
            module.attack(self.urls, self.forms)

        if self.myls.getUploads() != []:
            print('')
            print(_("Upload scripts found:"))
            print("----------------------")
            for upload_form in self.myls.getUploads():
                print(upload_form)

        # Pick a default destination matching the report format.
        if not self.outputFile:
            if self.reportGeneratorType == "html":
                self.outputFile = self.COPY_REPORT_DIR
            elif self.reportGeneratorType == "txt":
                self.outputFile = self.REPORT_FILE + ".txt"
            else:
                self.outputFile = self.REPORT_FILE + ".xml"
        self.reportGen.generateReport(self.outputFile)
        print('')
        print(_("Report"))
        print("------")
        print(_("A report has been generated in the file {0}").format(self.outputFile))
        if self.reportGeneratorType == "html":
            print(_("Open {0}/index.html with a browser to see this report.").format(self.outputFile))

    def setTimeOut(self, timeout=6.0):
        """Set the timeout for the time waiting for a HTTP response."""
        self.http_engine.setTimeOut(timeout)

    def setVerifySsl(self, verify=True):
        """Set whether SSL must be verified."""
        self.http_engine.setVerifySsl(verify)

    def setProxy(self, proxy=""):
        """Set a proxy to use for HTTP requests."""
        self.http_engine.setProxy(proxy)

    def addStartURL(self, url):
        """Specify an URL to start the scan with. Can be called several times."""
        self.myls.addStartURL(url)

    def addExcludedURL(self, url):
        """Specify an URL to exclude from the scan. Can be called several times."""
        self.myls.addExcludedURL(url)

    def setCookieFile(self, cookie):
        """Load session data from a cookie file."""
        self.http_engine.setCookieFile(cookie)

    def setAuthCredentials(self, auth_basic):
        """Set credentials to use if the website requires an authentication."""
        self.http_engine.setAuthCredentials(auth_basic)

    def setAuthMethod(self, auth_method):
        """Set the authentication method to use."""
        self.http_engine.setAuthMethod(auth_method)

    def addBadParam(self, bad_param):
        """Exclude a parameter from an url (urls with this parameter will be
        modified). This function can be called several times."""
        self.myls.addBadParam(bad_param)

    def setNice(self, nice):
        """Define how many tuples of parameters / values must be sent for a
        given URL. Use it to prevent infinite loops."""
        self.myls.setNice(nice)

    def setScope(self, scope):
        """Set the scope of the crawler for the analysis of the web pages."""
        self.target_scope = scope
        self.myls.setScope(scope)

    def setColor(self):
        """Put colors in the console output (terminal must support colors)."""
        self.color = 1

    def verbosity(self, vb):
        """Define the level of verbosity of the output."""
        self.verbose = vb
        self.myls.verbosity(vb)

    def setModules(self, options=""):
        """Store the comma-separated module selection ("[+|-]module[:method]")
        that will be applied when the attack modules are loaded."""
        self.options = options

    def setReportGeneratorType(self, repGentype="xml"):
        """Set the format of the generated report. Can be xml, html or txt."""
        self.reportGeneratorType = repGentype

    def setOutputFile(self, outputFile):
        """Set the filename where the report will be written."""
        self.outputFile = outputFile

if __name__ == "__main__":
    # Command-line entry point: parse the options, crawl the target (or load a
    # previous crawl), launch the attacks and write the report.
    #
    # Exit status: 0 on success or help, 1 for an environment problem or when
    # nothing was found to attack, 2 for a command-line error. sys.exit() is
    # deliberately not intercepted, otherwise these codes would be lost.
    doc = _("wapitiDoc")

    prox = ""
    auth = []
    crawlerPersister = CrawlerPersister()
    crawlerFile = None
    attackFile = None

    print(_("Wapiti-2.3.0 (wapiti.sourceforge.net)"))

    # Fix for bug #31
    if sys.getdefaultencoding() != "utf-8":
        reload(sys)
        sys.setdefaultencoding("utf-8")

    import requests
    if requests.__version__.startswith("0."):
        print("Error: You have an outdated version of python-requests. Please upgrade")
        sys.exit(1)

    if len(sys.argv) < 2:
        print(doc)
        sys.exit(0)
    if '-h' in sys.argv or '--help' in sys.argv:
        print(doc)
        sys.exit(0)

    if not os.path.isdir(crawlerPersister.CRAWLER_DATA_DIR):
        os.makedirs(crawlerPersister.CRAWLER_DATA_DIR)

    # The first argument is always the target URL; options come after it.
    url = sys.argv[1]
    wap = Wapiti(url)

    def defaultDataFile():
        """Return the default crawler data file for the target: <data dir>/<host>.xml"""
        hostname = url.split("://")[1].split("/")[0]
        return u"{0}{1}{2}.xml".format(crawlerPersister.CRAWLER_DATA_DIR,
                                       os.path.sep,
                                       hostname)

    def isHttpUrl(value):
        """Tell whether `value` starts with an http:// or https:// scheme."""
        return value.startswith(("http://", "https://"))

    try:
        opts, args = getopt.getopt(sys.argv[2:],
                                   "hup:s:x:c:a:r:v:t:m:o:f:n:kib:",
                                   ["help", "color", "proxy=", "start=", "exclude=",
                                    "cookie=", "auth=", "remove=", "verbose=", "timeout=",
                                    "module=", "output=", "format=", "nice=",
                                    "attack", "continue", "scope=", "verify-ssl=", "auth-method="])
    except getopt.GetoptError as e:
        print(e)
        sys.exit(2)

    try:
        # Option names are mutually exclusive, so a single elif chain is enough.
        for o, a in opts:
            if o in ["-h", "--help"]:
                print(doc)
                sys.exit(0)
            elif o in ["-s", "--start"]:
                if not isHttpUrl(a):
                    raise InvalidOptionValue(o, a)
                wap.addStartURL(a)
            elif o in ["-x", "--exclude"]:
                if not isHttpUrl(a):
                    raise InvalidOptionValue(o, a)
                wap.addExcludedURL(a)
            elif o in ["-p", "--proxy"]:
                if not isHttpUrl(a):
                    raise InvalidOptionValue(o, a)
                wap.setProxy(a)
            elif o in ["-c", "--cookie"]:
                if not os.path.isfile(a):
                    raise InvalidOptionValue(o, a)
                wap.setCookieFile(a)
            elif o in ["-a", "--auth"]:
                if a.find("%") < 0:
                    raise InvalidOptionValue(o, a)
                # Split on the first "%" only so that a password containing
                # "%" is kept whole.
                # NOTE(review): assumes the credentials are a login/password
                # pair -- confirm against HTTP.setAuthCredentials.
                auth = a.split("%", 1)
                wap.setAuthCredentials(auth)
            elif o in ["--auth-method"]:
                if a not in ["basic", "digest", "kerberos", "ntlm"]:
                    raise InvalidOptionValue(o, a)
                wap.setAuthMethod(a)
            elif o in ["-r", "--remove"]:
                wap.addBadParam(a)
            elif o in ["-n", "--nice"]:
                if not a.isdigit():
                    raise InvalidOptionValue(o, a)
                wap.setNice(int(a))
            elif o in ["-u", "--color"]:
                wap.setColor()
            elif o in ["-v", "--verbose"]:
                if not a.isdigit():
                    raise InvalidOptionValue(o, a)
                wap.verbosity(int(a))
            elif o in ["-t", "--timeout"]:
                if not a.isdigit():
                    raise InvalidOptionValue(o, a)
                wap.setTimeOut(int(a))
            elif o in ["-m", "--module"]:
                wap.setModules(a)
            elif o in ["-o", "--output"]:
                wap.setOutputFile(a)
            elif o in ["-f", "--format"]:
                # Only accept a format for which a report generator is declared.
                known_formats = [repGenInfo.getKey()
                                 for repGenInfo in wap.xmlRepGenParser.getReportGenerators()]
                if a not in known_formats:
                    raise InvalidOptionValue(o, a)
                wap.setReportGeneratorType(a)
            elif o in ["-b", "--scope"]:
                if a not in ["page", "folder", "domain"]:
                    raise InvalidOptionValue(o, a)
                wap.setScope(a)
            elif o in ["-k", "--attack"]:
                # -k / --attack are declared without argument in the getopt
                # specification above, so `a` is empty and the default applies.
                if a != "" and a[0] != '-':
                    attackFile = a
                else:
                    attackFile = defaultDataFile()
            elif o in ["-i", "--continue"]:
                # Same remark as for -k: declared without argument.
                if a != '' and a[0] != '-':
                    crawlerFile = a
                else:
                    crawlerFile = defaultDataFile()
            elif o in ["--verify-ssl"]:
                if not a.isdigit():
                    raise InvalidOptionValue(o, a)
                wap.setVerifySsl(bool(int(a)))

    except InvalidOptionValue as msg:
        print(msg)
        sys.exit(2)

    if attackFile is not None:
        if crawlerPersister.isDataForUrl(attackFile) == 1:
            # Reuse the links and forms saved by a previous crawl.
            crawlerPersister.loadXML(attackFile)
            wap.urls = crawlerPersister.getBrowsed()
            wap.forms = crawlerPersister.getForms()
            print(_("File {0} loaded. Wapiti will use it to perform the attack").format(attackFile))
        else:
            print(_("File {0} not found. Wapiti will scan the web site again").format(attackFile))
            wap.browse(crawlerFile)
    else:
        wap.browse(crawlerFile)

    try:
        wap.attack()
    except KeyboardInterrupt:
        print('')
        print(_("Attack process interrupted. To perform again the attack, "
                "lauch Wapiti with \"-i\" or \"-k\" parameter."))
        print('')
