# search, a bzr plugin for searching within bzr branches/repositories.
# Copyright (C) 2011 Jelmer Vernooij
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as published
# by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301 USA
#

"""Smart server integration for bzr-search."""

from bzrlib import remote
from bzrlib.bzrdir import BzrDir
from bzrlib.errors import (
    ErrorFromSmartServer,
    UnexpectedSmartServerResponse,
    )
from bzrlib.smart.branch import (
    SmartServerBranchRequest,
    )
from bzrlib.smart.request import SuccessfulSmartServerResponse

from bzrlib.plugins.search import errors, index


def _encode_termlist(termlist):
    return ["\0".join([k.encode('utf-8') for k in term]) for term in termlist]

def _decode_termlist(termlist):
    return [tuple([k.decode('utf-8') for k in term.split('\0')]) for term in termlist]


class RemoteIndex(object):
    """Index accessed over a smart server.

    Every operation is forwarded as a smart-server request through the
    client handed to the constructor.
    """

    def __init__(self, client, path, branch=None):
        """Create a proxy for a remote index.

        :param client: Smart client used to issue the requests.
        :param path: Path sent as the first argument of every ``Index.*``
            request.
        :param branch: Optional branch; its repository is attached to the
            revision and file text hits returned by search().
        """
        self._client = client
        self._path = path
        self._branch = branch

    def _call(self, method, *args, **err_context):
        """Issue ``method`` and return the client's response.

        Errors reported by the server are passed to _translate_error().
        """
        try:
            return self._client.call(method, *args)
        except ErrorFromSmartServer as err:
            self._translate_error(err, **err_context)

    def _call_expecting_body(self, method, *args, **err_context):
        """Issue ``method`` for a request whose response carries a body.

        Errors reported by the server are passed to _translate_error().
        """
        try:
            return self._client.call_expecting_body(method, *args)
        except ErrorFromSmartServer as err:
            self._translate_error(err, **err_context)

    def _call_with_body_bytes(self, method, args, body_bytes, **err_context):
        """Issue ``method`` with ``body_bytes`` as the request body.

        Errors reported by the server are passed to _translate_error().
        """
        try:
            return self._client.call_with_body_bytes(method, args, body_bytes)
        except ErrorFromSmartServer as err:
            self._translate_error(err, **err_context)

    def _call_with_body_bytes_expecting_body(self, method, args, body_bytes,
                                             **err_context):
        """Issue ``method`` with a request body, expecting a response body.

        Errors reported by the server are passed to _translate_error().
        """
        try:
            return self._client.call_with_body_bytes_expecting_body(
                method, args, body_bytes)
        # ErrorFromSmartServer is the bzrlib.errors class imported at the
        # top of this file, the same one the other _call* helpers catch.
        except ErrorFromSmartServer as err:
            self._translate_error(err, **err_context)

    def _translate_error(self, err, **context):
        """Hand a server-side error to bzrlib's generic error translation.

        This index is passed along as the ``index`` context value.
        """
        # NOTE(review): the _call* helpers rely on this raising; if it ever
        # returned normally they would return None - confirm against bzrlib.
        remote._translate_error(err, index=self, **context)

    @classmethod
    def open(cls, branch):
        """Open the existing index of a remote branch.

        :param branch: The remote branch whose index should be opened.
        :return: An index proxy bound to ``branch``.
        :raises errors.NoSearchIndex: If the server answers ``('no', )``.
        :raises UnexpectedSmartServerResponse: On any other answer than
            ``('yes', )``.
        """
        path = branch._remote_path()
        # This might raise UnknownSmartMethod,
        # but the caller should handle that.
        response = branch._call("Branch.open_index", path)
        if response == ('no', ):
            raise errors.NoSearchIndex(branch.user_transport)
        if response != ('yes', ):
            raise UnexpectedSmartServerResponse(response)
        return cls(branch._client, path, branch)

    @classmethod
    def init(cls, branch):
        """Ask the server to create an index for a remote branch.

        :param branch: The remote branch to create an index for.
        :return: An index proxy bound to ``branch``.
        :raises UnexpectedSmartServerResponse: If the server does not
            answer ``('ok', )``.
        """
        path = branch._remote_path()
        response = branch._call("Branch.init_index", path)
        if response != ('ok', ):
            raise UnexpectedSmartServerResponse(response)
        return cls(branch._client, path, branch)

    def index_branch(self, branch, tip_revision):
        """Index revisions from a branch.

        :param branch: The branch to index.
        :param tip_revision: The tip of the branch.
        """
        self.index_revisions(branch, [tip_revision])

    def index_revisions(self, branch, revisions_to_index):
        """Index some revisions from branch.

        The revision ids are sent newline-separated as the request body.

        :param branch: A branch to index.
        :param revisions_to_index: A set of revision ids to index.
        :raises UnexpectedSmartServerResponse: If the server does not
            answer ``('ok', )``.
        """
        body = "\n".join(revisions_to_index)
        response = self._call_with_body_bytes(
            'Index.index_revisions', (self._path, branch._remote_path(),),
            body)
        if response != ('ok', ):
            raise UnexpectedSmartServerResponse(response)

    def indexed_revisions(self):
        """Return the revision_keys that this index contains terms for.

        This is a generator: the request is only issued once iteration
        starts, and one 1-tuple is yielded per complete line of the
        streamed response body.

        :raises UnexpectedSmartServerResponse: If the server does not
            answer ``('ok', )``.
        """
        response, handler = self._call_expecting_body(
            'Index.indexed_revisions', self._path)
        if response != ('ok', ):
            raise UnexpectedSmartServerResponse(response)
        pending = ""
        for chunk in handler.read_streamed_body():
            pending += chunk
            lines = pending.split("\n")
            # The last element is an incomplete line (or empty); keep it
            # until the next chunk completes it.
            pending = lines.pop()
            for revid in lines:
                yield (revid, )

    def search(self, termlist):
        """Trivial set-based search of the index.

        :param termlist: A list of terms.
        :return: An iterator of SearchResults for documents indexed by all
            terms in the termlist.
        :raises UnexpectedSmartServerResponse: If the server does not
            answer ``('ok', )``.
        """
        index._ensure_regexes()
        response, handler = self._call_expecting_body('Index.search',
            self._path, _encode_termlist(termlist))
        if response != ('ok', ):
            raise UnexpectedSmartServerResponse(response)
        # We can't yield, since the caller might try to look up results
        # over the same medium. Collect every hit first instead.
        hits = []
        pending = ""
        for chunk in handler.read_streamed_body():
            pending += chunk
            lines = pending.split("\n")
            # Keep the trailing partial line for the next chunk.
            pending = lines.pop()
            for line in lines:
                # Slicing (rather than indexing) keeps an empty line from
                # raising IndexError; it is reported as an unknown kind.
                kind, payload = line[:1], line[1:]
                if kind == 'r':
                    hit = index.RevisionHit(
                        self._branch.repository, (payload, ))
                elif kind == 't':
                    hit = index.FileTextHit(self, self._branch.repository,
                        tuple(payload.split("\0")), termlist)
                elif kind == 'p':
                    hit = index.PathHit(payload)
                else:
                    raise AssertionError("Unknown hit kind %r" % kind)
                # Record every parsed line, not just the last of a chunk.
                hits.append(hit)
        return iter(hits)

    def suggest(self, termlist):
        """Generate suggestions for extending a search.

        :param termlist: A list of terms.
        :return: An iterator of terms that start with the last search term in
            termlist, and match the rest of the search.
        :raises UnexpectedSmartServerResponse: If the first element of the
            response is not ``'ok'``.
        """
        response = self._call('Index.suggest',
            self._path, _encode_termlist(termlist))
        if response[0] != 'ok':
            raise UnexpectedSmartServerResponse(response)
        return [(suggestion.decode('utf-8'),) for suggestion in response[1]]


class SmartServerBranchRequestOpenIndex(SmartServerBranchRequest):
    """Report whether a branch has a search index that can be opened."""

    def do_with_branch(self, branch):
        """Try to open the index of ``branch``.

        :param branch: The branch whose index should be opened.
        :return: A successful response of ``('yes', )`` when the index
            opens, or ``('no', )`` when errors.NoSearchIndex is raised.
        """
        try:
            index.open_index_branch(branch)
        except errors.NoSearchIndex:
            answer = 'no'
        else:
            answer = 'yes'
        return SuccessfulSmartServerResponse((answer, ))


class SmartServerBranchRequestInitIndex(SmartServerBranchRequest):
    """Create a search index for a branch."""

    def do_with_branch(self, branch, format=None):
        """Initialise an index for ``branch``.

        :param branch: The branch to create the index for.
        :param format: Optional index format; when omitted,
            index.init_index() is called without one.
        :return: A successful response of ``('ok', )``.
        """
        init_args = [branch]
        if format is not None:
            init_args.append(format)
        index.init_index(*init_args)
        return SuccessfulSmartServerResponse(('ok', ))


class SmartServerIndexRequest(SmartServerBranchRequest):
    """Common base for requests that operate on a branch's index."""

    def do_with_branch(self, branch, *args):
        """Open the index of ``branch`` and dispatch to do_with_index().

        :param branch: The branch whose index the request operates on.
        :param args: Remaining request arguments, passed through unchanged.
        :return: Whatever do_with_index() returns.
        """
        return self.do_with_index(index.open_index_branch(branch), *args)

    def do_with_index(self, index, *args):
        """Handle the request for an opened index; subclasses override.

        :param index: The opened index.
        :param args: Remaining request arguments.
        """
        raise NotImplementedError(self.do_with_index)


class SmartServerIndexRequestIndexRevisions(SmartServerIndexRequest):
    """Index a set of revisions.

    The request names the branch to index as an argument and carries the
    revision ids, newline-separated, in its body.
    """

    def do_body(self, body_bytes):
        """Index the revisions listed in the request body.

        Uses the index and branch stored by do_with_index().

        :param body_bytes: Revision ids separated by newlines.
        :return: A successful response of ``('ok', )``.
        """
        # Splitting an empty body yields [''], so drop empty entries
        # rather than asking the index to index an empty revision id.
        revids = [revid for revid in body_bytes.split("\n") if revid]
        self._index.index_revisions(self._branch, revids)
        return SuccessfulSmartServerResponse(('ok', ))

    def do_with_index(self, index, branch_path):
        """Remember the index and open the branch that is to be indexed.

        :param index: The opened index.
        :param branch_path: Client path of the branch to index.
        :return: None, to indicate that a request body is expected.
        :raises NotBranchError: If ``branch_path`` holds a branch reference.
        """
        # NotBranchError is a bzrlib.errors class; the ``errors`` name in
        # this module is the search plugin's own errors module.
        from bzrlib.errors import NotBranchError
        self._index = index
        transport = self.transport_from_client_path(branch_path)
        controldir = BzrDir.open_from_transport(transport)
        if controldir.get_branch_reference() is not None:
            raise NotBranchError(transport.base)
        self._branch = controldir.open_branch(ignore_fallbacks=True)
        # Indicate we want a body
        return None


class SmartServerIndexRequestIndexedRevisions(SmartServerIndexRequest):
    """Stream the revision keys the index contains terms for."""

    def body_stream(self, index):
        """Generate one body line per indexed revision key.

        :param index: The opened index.
        :return: A generator of strings; each is the elements of one
            revision key joined with NUL and terminated by a newline.
        """
        for revision_key in index.indexed_revisions():
            joined = "\0".join(revision_key)
            yield "%s\n" % joined

    def do_with_index(self, index):
        """Answer ``('ok', )`` with the revision keys as a streamed body.

        :param index: The opened index.
        """
        stream = self.body_stream(index)
        return SuccessfulSmartServerResponse(('ok', ), body_stream=stream)


class SmartServerIndexRequestSuggest(SmartServerIndexRequest):
    """Suggest terms that extend a search."""

    def do_with_index(self, index, termlist):
        """Ask the index for suggestions and return them UTF-8 encoded.

        :param index: The opened index.
        :param termlist: The encoded term list received from the client.
        :return: A successful response of ``('ok', suggestions)`` where
            suggestions is a list of UTF-8 encoded strings.
        """
        terms = _decode_termlist(termlist)
        encoded = []
        for (suggestion,) in index.suggest(terms):
            encoded.append(suggestion.encode('utf-8'))
        return SuccessfulSmartServerResponse(('ok', encoded))


class SmartServerIndexRequestSearch(SmartServerIndexRequest):
    """Run a search and stream the hits back to the client."""

    def body_stream(self, results):
        """Serialise search hits, one newline-terminated line per hit.

        The first character of each line identifies the kind of hit:
        ``t`` for a file text hit (followed by the two elements of its
        text key separated by NUL), ``r`` for a revision hit (followed by
        the first element of its revision key) and ``p`` for a path hit
        (followed by its UTF-8 path).

        :param results: An iterable of hits.
        :raises AssertionError: For a hit of any other type.
        """
        for result in results:
            if isinstance(result, index.FileTextHit):
                line = "t%s\0%s\n" % result.text_key
            elif isinstance(result, index.RevisionHit):
                line = "r%s\n" % result.revision_key[0]
            elif isinstance(result, index.PathHit):
                line = "p%s\n" % result.path_utf8
            else:
                raise AssertionError("Unknown hit type %r" % result)
            yield line

    def do_with_index(self, index, termlist):
        """Search ``index`` and answer with the hits as a streamed body.

        :param index: The opened index.
        :param termlist: The encoded term list received from the client.
        :return: A successful response of ``('ok', )`` whose body stream
            is produced by body_stream().
        """
        terms = _decode_termlist(termlist)
        stream = self.body_stream(index.search(terms))
        return SuccessfulSmartServerResponse(('ok',), body_stream=stream)
