from functools import wraps
from math import log as ln
from types import FunctionType
from typing import Iterable, List, Dict, Callable

from websites_test_framework import files_extensions
from websites_test_framework.custom_types import (
    TestInfo,
    TestReturn,
    TestName,
    TestMethod,
    TestsResults,
    PathLike,
)
from websites_test_framework.files_extensions import PIC_FILES
from websites_test_framework.param import DEFAULT_SUCCESS_MESSAGE, SCORES_NDIGITS
from websites_test_framework.tools import (
    as_path,
    warn,
    cached_property,
    is_valid_xml,
    is_url_relative,
    url_exists,
)
from websites_test_framework.website import Website


def test(title: str = None, weight: float = 1, relative=False):
    """Function generating a decorator, to mark a method as a test.

    Tests must return a tuple `(score, log)`, where `score` is a float (or int)
    between 0 and 1 and `log` is a list of strings.

    If `relative` is True, the test may return a score above 1.
    The class CollectTestResults will then divide the score of each
    website by the best score.

    The decorated method:
        - rounds the score to `SCORES_NDIGITS` digits,
        - appends `DEFAULT_SUCCESS_MESSAGE` to an empty log when the score is 1,
        - raises ValueError if the returned values are invalid.

    The decorated method also gets the attributes `is_test` (always True),
    `title`, `weight` and `is_relative`, which `WebsiteTest` reads to discover tests.
    """

    def mark_function_as_test(f):
        @wraps(f)
        def new_f(self, *args, **kw) -> TestReturn:
            result: TestReturn = f(self, *args, **kw)
            # noinspection PyUnboundLocalVariable
            if (
                isinstance(result, tuple)
                and len(result) == 2
                and isinstance((score := result[0]), (float, int))
                # Round first, so that floating point noise (like 1.0000000000000002)
                # does not make a perfect score fail the `<= 1` check.
                and (0 <= (score := round(score, SCORES_NDIGITS)))
                and (relative or score <= 1)
                and isinstance((log := result[1]), list)
                and all(isinstance(elt, str) for elt in log)
            ):
                if not log and score == 1:
                    log.append(DEFAULT_SUCCESS_MESSAGE)
                return score, log
            # Relative tests have no upper bound: report the actual constraint.
            expected = "0 <= success_rate" if relative else "0 <= success_rate <= 1"
            raise ValueError(
                f"Invalid return values for website test {f.__qualname__!r}: {result!r}\n"
                f"A test must return (success_rate, log) where "
                f"{expected} and log is a list of strings."
            )

        new_f.is_test = True
        new_f.title = title
        new_f.weight = weight
        new_f.is_relative = relative
        return new_f

    return mark_function_as_test


class WebsiteTest:
    """Base class for testing a website.

    Main method is `run()`, which returns a dictionary with all tests' results.

    To write a battery of tests, subclass this class and add custom tests as methods.

    Each method decorated with `@test` is a test.

    Each test must return a tuple `(score, log)`:
        - score: a number between 0 and 1 (tests declared with `relative=True`
          may return any non-negative number),
        - log: a list of strings.

    One should also override the `.get_authors()` method, which must return the list
    of the authors of the website.
    """

    __test__ = False  # For pytest: this class must not be collected as a test class.

    def __init__(
        self,
        website_path: PathLike,
        skipped_tests: Iterable[str] = tuple(),
        path_on_server: str = "/",
    ):
        """Prepare the tests of the website located at `website_path`.

        `skipped_tests` contains the names (methods' names) of the tests
        which must not be run.

        `path_on_server` is passed as-is to `Website`.
        """
        self.path = as_path(website_path)
        self.website = Website(self.path, path_on_server)
        # Copy the iterable: a one-shot iterator (e.g. a generator) would otherwise
        # be exhausted by the first membership test in `tests_names`.
        self.skipped_tests = frozenset(skipped_tests)

        if len(self.website.html_files) == 0:
            warn("No HTML file", str(self.path))

    def __str__(self):
        return f"{self.__class__.__name__}('{self.path}')"

    # ---------------------------------------------
    # This method must be implemented by subclasses
    # ---------------------------------------------

    def get_authors(self) -> List[str]:
        """WebsiteTest subclasses should implement this method.

        It must return the list of the authors of the website.
        """
        raise NotImplementedError

    # -----------------------------------------
    # Methods providing information about tests
    # -----------------------------------------

    @cached_property
    def tests_names(self) -> List[TestName]:
        """Return active tests' names, sorted by name.

        Tests whose name is included in `self.skipped_tests` are not returned.
        """
        return [
            name for name in self.__class__.get_all_tests_names() if name not in self.skipped_tests
        ]

    @classmethod
    def get_all_tests_names(cls) -> List[TestName]:
        """Return the sorted list of tests' names, where a name is the method's name.

        A test is any function attribute of the class flagged by the `@test` decorator.

        All tests' names are returned, even those included in `skipped_tests`.
        This is a low-level method; using the `.tests_names` property is recommended
        instead.
        """
        return sorted(
            TestName(name)
            for name in dir(cls)
            if isinstance(attribute := getattr(cls, name), FunctionType)
            and getattr(attribute, "is_test", False)
        )

    @classmethod
    def _test(cls, name: TestName) -> TestMethod:
        """Return the (decorated) test method called `name`."""
        return getattr(cls, name)

    @classmethod
    def get_tests_infos(cls) -> Dict[TestName, TestInfo]:
        """Get title, weight and relativity for all tests.

        Tests are sorted by (method) name.
        """
        infos: Dict[TestName, TestInfo] = {}
        for name in cls.get_all_tests_names():
            method = cls._test(name)
            infos[name] = {
                "weight": method.weight,
                "title": method.title,
                "is_relative": method.is_relative,
            }
        return infos

    # ------------------------------------------
    # Main method: should be called to run tests
    # ------------------------------------------

    def run(self, filter_func: Callable[[TestName], bool] = None) -> TestsResults:
        """Main method: run all active tests and return results as a dict.

        Tests' results are sorted by test's name. Each result is a dict
        `{"score": score, "log": log}`, where the log lines are joined with newlines.

        If `filter_func` is not `None`, it should be a function taking a test's name
        as argument and returning a boolean value; only tests for which it returns
        a true value are run.
        """
        print(f"\n\033[1mTesting {self.path}...\033[0m")
        all_results: TestsResults = {}
        for name in filter(filter_func, self.tests_names):
            result: TestReturn = getattr(self, name)()
            score, log = result
            all_results[name] = {"score": score, "log": "\n".join(log)}
        return all_results

    # ------------------------------------------------
    # Tests: all tests must start with @test decorator
    # ------------------------------------------------

    @test(title="UTF8 ?", weight=1)
    def test_utf8(self):
        """Test that every web file is UTF-8 encoded.

        80 % of the score is proportional to the ratio of UTF-8 files,
        the remaining 20 % are granted only if all files are UTF-8.
        """
        files_number = 0
        score = 0
        log = []
        for tested_file in self.website.web_files:
            files_number += 1
            if tested_file.is_utf8:
                score += 1
            else:
                log.append(
                    f"{tested_file.relative_path} doit être encodé en UTF8 "
                    f"({tested_file.encoding} détecté)."
                )
        if files_number == 0:
            return 0, ["No HTML/CSS files !"]
        return 0.8 * score / files_number + (0.2 if score == files_number else 0), log

    @test(title="charset", weight=0.5)
    def test_declared_encoding(self):
        """Test that every HTML file declares a UTF-8 charset with a `<meta charset>` tag.

        Each file earns one point for having exactly one `<meta charset>` tag,
        and one more point if this tag declares UTF-8.
        """
        score = 0
        log = []
        if not len(self.website.html_files):
            return 0, ["No HTML file !"]
        for html_file in self.website.html_files:
            metas = html_file.structure.find_all("meta", charset=True)
            if len(metas) == 1:
                score += 1
                # Accept "UTF-8", "utf8", " Utf-8 "...
                if metas[0].attrs["charset"].lower().strip().replace("-", "") == "utf8":
                    score += 1
                else:
                    log.append(f"Bad encoding declaration: {html_file.relative_path}")
            else:
                log.append(f"No encoding declaration: {html_file.relative_path}")
        return score / (2 * len(self.website.html_files)), log

    @test(title="XML ?", weight=0.5)
    def test_xml(self):
        """Test that HTML files are also valid XML documents.

        90 % of the score is proportional to the ratio of valid files,
        the remaining 10 % are a bonus granted only if every file is valid.
        """
        files = self.website.html_files
        if len(files) == 0:
            return 0, ["No HTML file !"]
        log = []
        counter = 0
        for html_file in files:
            valid, msg = is_valid_xml(html_file.path)
            if valid:
                counter += 1
            if msg:
                log.append(f"File {html_file.relative_path}: {msg}")
        if counter == len(files):  # All XML, well done !
            return 1, log
        return 0.9 * counter / len(files), log

    @staticmethod
    def _test_comments(
        files, minimal_score_if_any_comment=0.5, target_ratio: float = 0.01
    ) -> TestReturn:
        """Test if files are commented enough.

        Without any comment, the score is 0. Otherwise, it is at least
        `minimal_score_if_any_comment`, and the remaining part is computed from:
            - 40 %: number of comments per file (maximal with 1 comment per file on average),
            - 30 %: number of comments / number of lines (maximal when >= target_ratio),
            - 30 %: comments characters / total characters (maximal when >= target_ratio).

        `target_ratio` must be a float in ]0, 1].

        Raise ValueError if `target_ratio` is out of range.
        """
        # A zero ratio would make the ratio computations below divide by zero.
        if not (0 < target_ratio <= 1):
            raise ValueError("`target_ratio` value must be between 0 and 1.")
        comments: List[str] = []
        score = 0.0
        for file in files:
            comments.extend(file.comments)
        log = [f"Nombre de commentaires: {len(comments)}"]

        if len(comments) >= 1:
            # At least one comment per file on average ?
            score += 0.4 * min(len(comments) / len(files), 1)

            # Ratio number of comments / number of lines (max score if ratio >= target_ratio)
            number_of_lines = sum(file.number_of_lines for file in files)
            score += 0.3 * min(len(comments) / (target_ratio * number_of_lines), 1)

            # Ratio comments characters / total characters (max score if ratio >= target_ratio)
            total_comments_characters = sum(len(comment) for comment in comments)
            total_characters = sum(file.number_of_characters for file in files)
            score += 0.3 * min(total_comments_characters / (target_ratio * total_characters), 1)
            assert 0 <= score <= 1, score
            score = minimal_score_if_any_comment + (1 - minimal_score_if_any_comment) * score

        return score, log

    @test(title="HTML comments", weight=0.5)
    def test_html_comments(self) -> TestReturn:
        """Test that HTML files are commented enough (see `_test_comments`)."""
        return self._test_comments(
            self.website.html_files, minimal_score_if_any_comment=0.5, target_ratio=0.01
        )

    @test(title="CSS comments", weight=1)
    def test_commentaires_css(self) -> TestReturn:
        """Test that used CSS files are commented enough (see `_test_comments`)."""
        return self._test_comments(
            self.website.used_css_files, minimal_score_if_any_comment=0.25, target_ratio=0.02
        )

    @test(title="HTML Validity", weight=4)
    def test_html_validity(self) -> TestReturn:
        """Test HTML validity, using the errors and warnings reported for each HTML file.

        An error costs 1 penalty point and a warning 0.25. Without any penalty the
        score is 1; otherwise it is
        `0.5 - 0.8 * penalty / (number of HTML files + 1 % of the total number of tags)`,
        floored at 0.
        """
        if len(self.website.html_files) == 0:
            return 0, ["No HTML file found !"]
        log: List[str] = []
        penalty: float = 0
        for file in self.website.html_files:
            log.extend(file.errors)
            log.extend(file.warnings)
            penalty += len(file.errors) + 0.25 * len(file.warnings)
        if penalty == 0:
            return 1, ["HTML: OK"]
        # Max score of 0.5 if any error or warning is reported
        score = 0.5 - 0.8 * penalty / (
            len(self.website.html_files) + 0.01 * self.website.tags_global_number
        )
        return max(0.0, score), log

    @test(title="CSS Validity", weight=3)
    def test_css_validity(self) -> TestReturn:
        """Test CSS validity, using the number of errors reported for each CSS file.

        Without any error the score is 1; otherwise it is
        `0.5 - 0.8 * errors / (number of CSS files + 1 % of the total number of CSS rules)`,
        floored at 0.
        """
        if len(self.website.css_files) == 0:
            return 0, ["No CSS file found !"]
        log: List[str] = []
        penalty = 0
        for file in self.website.css_files:
            n = file.errors  # Number of errors in this CSS file.
            if n > 0:
                log.append(f"File {file.name}: {n} error(s) found.")
            penalty += n
        if penalty == 0:
            return 1, ["CSS: OK"]
        # Max score of 0.5 if any error is reported
        score = 0.5 - 0.8 * penalty / (
            len(self.website.css_files) + 0.01 * self.website.css_rules_global_number
        )
        return max(0.0, score), log

    @test(title="HTML diversity", weight=1, relative=True)
    def test_html_diversity(self) -> TestReturn:
        """Test the diversity of the HTML tags used.

        Score is 1 when at least 30 different tags (`<br>` excluded) are used,
        or when different tags make up at least 10 % of all tags.
        """
        log = []
        tags_count = self.website.tags_count
        total_tags = self.website.tags_global_number
        # Using <br/> is usually not a good practice.
        total_different_tags = sum(1 for key in tags_count if key != "br")
        log.append(f"Total tags: {total_tags}")
        log.append(f"Total different tags: {total_different_tags}")
        if total_tags == 0:
            return 0, log
        # 30 different tags is enough to have maximal score.
        score = min(1.0, max(10 * total_different_tags / total_tags, total_different_tags / 30))
        return score, log

    @test(title="div & span", weight=1, relative=True)
    def test_not_to_much_div_and_span(self) -> TestReturn:
        """Test that generic tags (`<div>` and `<span>`) are not overused.

        Score is maximal when `<div>` and `<span>` tags make up at most 10 % of all tags.
        """
        tags_count = self.website.tags_count
        div_span_count = tags_count["div"] + tags_count["span"]
        log = [f"Generic tags: {div_span_count}"]
        total_tags = sum(tags_count.values())
        if div_span_count == 0:
            return 1 if total_tags != 0 else 0, log
        return min(1.0, 0.1 * total_tags / div_span_count), log

    @test(title="HTML size", weight=0.5, relative=True)
    def test_number_of_tags(self):
        """Return the total number of tags, excluding div and span."""
        tags_count = self.website.tags_count
        total_tags = sum(tags_count.values()) - tags_count["div"] - tags_count["span"]
        return total_tags, [f"Tags number: {total_tags}"]

    @test(title="@rules", weight=0.5)
    def test_for_at_rules(self) -> TestReturn:
        """Test for @rules. Score is maximal if @rules number >= 3."""
        n = sum(len(file.at_rules_names) for file in self.website.css_files)
        return min(1.0, n / 3), [f"Number of @rules: {n}"]

    @test(title="@media", weight=0.5)
    def test_for_media_rule(self) -> TestReturn:
        """Test whether any CSS file uses an `@media` rule."""
        return (
            any("@media" in file.at_rules_names for file in self.website.css_files),
            [],
        )

    @test(title="@font-face", weight=0.15)
    def test_for_font_face_rule(self) -> TestReturn:
        """Test whether any CSS file uses a `@font-face` rule."""
        found = any("@font-face" in file.at_rules_names for file in self.website.css_files)
        return found, ["Used." if found else "Not used."]

    @test(title="@keyframes", weight=0.3)
    def test_for_keyframes_rule(self) -> TestReturn:
        """Test whether any CSS file uses a `@keyframes` rule."""
        found = any("@keyframes" in file.at_rules_names for file in self.website.css_files)
        return found, ["Used." if found else "Not used."]

    @test(title="No JS", weight=1)
    def test_no_javascript(self) -> TestReturn:
        """Test that no HTML file contains a `<script>` tag (score 0 without HTML files)."""
        for html_file in self.website.html_files:
            if html_file.tags_count["script"] != 0:
                return 0, [f"JS in {html_file.name}"]
        return 1 if self.website.html_files else 0, []

    @test(title="advanced selectors", weight=2)
    def test_for_advanced_selectors(self) -> TestReturn:
        """Test the use of advanced CSS selectors.

        Checked symbols: child combinator `>`, adjacent sibling combinator `+`,
        general sibling combinator `~`, pseudo-classes/elements `:` and attribute
        selectors `[`.
        Score is `sum(ln(1 + n_symbol)) / ln(1 + total number of selectors)`, capped at 1,
        where `n_symbol` is the number of selectors containing the symbol.
        """
        score = 0.0
        log = []
        # NB: the CSS child combinator is `>`; `<` does not exist in CSS selectors.
        for symbol in ">+~:[":
            n = sum(
                symbol in selector for file in self.website.css_files for selector in file.selectors
            )
            log.append(f"Selector {symbol} : {n}")
            score += ln(1 + n)
        n = sum(len(file.selectors) for file in self.website.css_files)
        log.append(f"Total: {n}")

        return min(score / ln(1 + n), 1.0) if n != 0 else 0, log

    @test(title="index.html", weight=1)
    def test_for_index(self) -> TestReturn:
        """Test that the root folder, and every folder containing HTML files, has an `index.html`.

        Score is 0 without a root `index.html`; otherwise, if some folders lack one,
        it is `0.2 + 0.8 * (ratio of folders with an index.html)`.
        """
        if not (self.path / "index.html").is_file():
            return 0, ["No index.html file !"]
        good = bad = 0
        log = []
        # `glob("**")` yields the root folder and all its sub-folders.
        for path in self.path.glob("**"):
            assert path.is_dir()
            if any(pth.suffix == ".html" for pth in path.glob("*")):
                # This folder contains html files, it should have an `index.html` file.
                if not (path / "index.html").is_file():
                    bad += 1
                    log.append(f"No index.html file in {path}.")
                else:
                    good += 1
        if bad:
            return 0.2 + 0.8 * good / (good + bad), log
        return 1, log

    @test(title="File types", weight=1)
    def test_files_types(self) -> TestReturn:
        """Test that the website only contains files of accepted types.

        Each bad file weighs twice as much as a good one in the score.
        Extensions missing from both `GOOD_FILES` and `BAD_FILES` trigger a warning.
        """
        good = bad = 0
        log = []
        for path in self.website.iterate_over_files_paths():
            ext = path.suffix.lower()
            if ext in files_extensions.GOOD_FILES:
                good += 1
            else:
                bad += 1
                log.append(f"Bad file type: {ext!r} ({path.name!r})")
                if ext not in files_extensions.BAD_FILES:
                    warn(
                        f"Unreferenced file type: {ext!r}",
                        "Update `websites_test_framework.files_extensions`",
                    )
        if bad + good == 0:
            return 0, ["No file !"]
        return good / (2 * bad + good), log

    @test(title="img folder", weight=0.5)
    def test_img_folder(self) -> TestReturn:
        """Test that pictures are stored in a dedicated folder (`img/`, `images/`, `photos/`...).

        Without any picture, score is 0.8 if the website has HTML files, 0 otherwise.
        """
        good = bad = 0
        log = []
        for path in self.website.iterate_over_files_paths(*PIC_FILES):
            # `rstrip("s")` accepts plural forms too (`images`, `pics`...).
            if path.parent.name.lower().rstrip("s") in ("img", "pic", "picture", "photo", "image"):
                good += 1
            else:
                bad += 1
                log.append(f"Parent image directory: {path.parent.name}")
        if bad + good == 0:
            return (0.8 if self.website.html_files else 0), ["No image file."]
        return good / (bad + good), log

    @test(title="css folder", weight=0.5)
    def test_css_folder(self) -> TestReturn:
        """Test that CSS files are stored in a dedicated folder (`css/`, `styles/`...)."""
        good = bad = 0
        log = []
        for path in self.website.iterate_over_files_paths(".css"):
            if path.parent.name.lower() in ("css", "style", "styles", "stylesheet", "stylesheets"):
                good += 1
            else:
                bad += 1
                log.append(f"Parent css directory: {path.parent.name}")
        if bad + good == 0:
            return 0, ["No CSS file !"]
        return good / (bad + good), log

    @test(title="Unused css", weight=0.5)
    def test_unused_css(self) -> TestReturn:
        """Test that every CSS file of the website is actually used.

        Score is 1 if all CSS files are used, `0.5 * used / total` otherwise
        (0 if no CSS file is used).
        """
        if not self.website.used_css_files:
            return 0, ["Not any used CSS files"]
        used = len(self.website.used_css_files)
        total = len(self.website.css_files)
        if used == total:
            return 1, []
        assert (
            used < total
        ), f"Path: {self.website.root}\nUsed:{[file.path for file in self.website.used_css_files]}"
        return 0.5 * used / total, [f"{total - used} unused CSS files"]

    def _test_url(self, file, tag):
        """Check the target of the `href` attribute of `tag`, a tag found in `file`.

        Return `(penalty, log)`:
            - a missing `href` attribute, a missing local target or a `file:` URL
              costs 1,
            - an unreachable remote (http/https) URL costs 0.2 (it is skipped if
              `url_exists` cannot tell),
            - a missing or non-unique `#fragment` target in the same file costs 0.5.
        Other URL schemes (like `mailto:`) are not checked.
        """
        log: List[str] = []
        penalty = 0
        # Only the attribute lookup is guarded: a KeyError raised elsewhere
        # must not be misreported as an invalid tag.
        try:
            href = tag.attrs["href"]
        except KeyError:
            return 1, [f"Invalid <a> tag: {tag}."]  # no href in <a> !
        url, _, fragment = href.strip().partition("#")
        if url:
            # False when the URL could not be checked at all (remote server unavailable).
            checked = True
            if is_url_relative(url):
                if not (file.path.parent / url).resolve().exists():
                    penalty += 1
            elif url.startswith("/"):
                url = self.website.rewrite_absolute_url(url)
                if not (self.website.root / url.lstrip("/")).resolve().exists():
                    penalty += 1
            elif url.startswith(("http:", "https:")):
                exist = url_exists(url)
                if exist is None:
                    checked = False
                    print(
                        msg := (
                            f"(Can't test {url} from {file.path}: "
                            "server down or access forbidden, skipping.)"
                        )
                    )
                    log.append(msg)
                elif not exist:
                    penalty += 0.2
            elif url.startswith("file:"):
                penalty += 1
            # TODO: tests for others URL ? Like mailto:...
            if penalty > 0:
                print(msg := f"Broken link: {url} from {file.path}")
                log.append(msg)
            elif checked:
                log.append(f"Link OK: {url}")
        elif fragment:
            # Inner link (`#anchor`): the id must exist and be unique in this file.
            # NB: testing for `other_file.html#anchor` is not implemented yet...
            n = len(file.structure.html.find_all(id=fragment))
            if n == 0:
                print(msg := f"Id {fragment!r} not found : file {file.path}")
                penalty += 0.5
            elif n == 1:
                msg = f"Fragment OK: #{fragment}"
            else:
                print(msg := f"Id {fragment!r} is not unique: file {file.path}")
                penalty += 0.5
            log.append(msg)
        return penalty, log

    @test(title="Broken css links", weight=0.5)
    def test_broken_css_references(self) -> TestReturn:
        """Test that stylesheets referenced by HTML and CSS files exist.

        Checked references are `<link rel="stylesheet">` tags in HTML `<head>`s,
        and relative `@import`s in CSS files. An HTML file without `<head>` and an
        unreadable CSS file each count as one broken reference.
        Score is 1 without any broken reference, `0.9 * (1 - bad / total)` otherwise.
        """
        total = 0
        bad = 0
        log: List[str] = []
        for file in self.website.html_files:
            if file.structure.head is None:
                bad += 1
                total += 1
                print(msg := f"Bad HTML file: {file.path}")
                log.append(msg)
            else:
                for link in file.structure.head.find_all("link"):
                    rel = link.attrs.get("rel", [])
                    if isinstance(rel, str):
                        # Depending on the parser, multi-valued attributes may not be split.
                        rel = rel.split()
                    # `rel` is a set of case-insensitive keywords
                    # (e.g. `rel="alternate stylesheet"`).
                    if "stylesheet" in (keyword.lower() for keyword in rel):
                        total += 1
                        penalty, log_ = self._test_url(file, link)
                        bad += penalty
                        log.extend(log_)
        for file in self.website.css_files:
            try:
                for url in file.structure.list_imports():
                    if is_url_relative(url):
                        total += 1
                        if not (file.path.parent / url).resolve().is_file():
                            bad += 1
                            print(msg := f"Broken link: {url}")
                            log.append(msg)
                    # TODO: test absolute (`/...`) imports too ?
            except LookupError:
                total += 1
                bad += 1  # binary file !
                log.append(f"Bad CSS file: {file.path}")
        if total == 0:
            return 0, ["No CSS file linked."]
        return max(0, 1 if bad == 0 else 0.9 * (1 - bad / total)), log

    @test(title="Broken urls", weight=0.5)
    def test_broken_url_in_html(self) -> TestReturn:
        """Test the targets of all `<a>` tags of HTML files (see `_test_url`).

        An HTML file without `<html>` counts as one broken link.
        Score is 1 without any broken link, `0.7 * (1 - penalties / total)` otherwise.
        """
        # TODO: search for broken img links too.
        total = 0
        bad = 0
        log: List[str] = []
        for file in self.website.html_files:
            if file.structure.html is None:
                bad += 1
                total += 1
                print(msg := f"Bad HTML file: {file.path}")
                log.append(msg)
            else:
                for a in file.structure.html.find_all("a"):
                    total += 1
                    penalty, log_ = self._test_url(file, a)
                    bad += penalty
                    log.extend(log_)
        if total == 0:
            return 0, ["No link found."]
        return max(0, 1 if bad == 0 else 0.7 * (1 - bad / total)), log