#!/usr/bin/env python3
import gitlab
import yaml
import os
import sys
import fnmatch
import pprint
import argparse
import dataclasses
import collections
import collections.abc
import concurrent.futures

def thread_pool(num_workers, func, items, num_retries=0):
  """Run ``func(item)`` for every item on a pool of worker threads.

  Each call is attempted up to ``num_retries + 1`` times. A failed attempt is
  reported on stderr (item identifier, attempt number and traceback) and then
  retried; an item that keeps failing is given up on without raising.

  Returns a list with one entry per item, in submission order: the value
  returned by ``func``, or ``None`` when every attempt failed.
  """
  def describe(item):
    # prefer a human readable identifier, falling back to the item itself
    for attr in ('path_with_namespace', 'id'):
      itemid = getattr(item, attr, None)
      if itemid is not None:
        return itemid
    return item

  def func_with_retries(item):
    for i in range(num_retries + 1):
      try:
        return func(item)
      except Exception:
        # only ordinary errors are retried: KeyboardInterrupt/SystemExit must
        # not be swallowed by the retry loop
        print(f"ERROR on item {describe(item)}, try #{i}", file=sys.stderr)
        sys.excepthook(*sys.exc_info())
    return None

  # leaving the `with` block waits for all the submitted calls to finish
  with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
    futures = [executor.submit(func_with_retries, item) for item in items]
  return [f.result() for f in futures]

def flatten(items):
  """Lazily yield the leaves of an arbitrarily nested iterable.

  Strings and bytes are treated as leaves and yielded whole rather than
  being split into their characters.
  """
  for element in items:
    is_leaf = isinstance(element, (str, bytes)) or not isinstance(element, collections.abc.Iterable)
    if is_leaf:
      yield element
      continue
    yield from flatten(element)

def is_glob(pattern):
  """Tell whether ``pattern`` contains a glob metacharacter (``*``, ``?`` or ``[``)."""
  for wildcard in ('*', '?', '['):
    if wildcard in pattern:
      return True
  return False

def asdict(d):
  """Return a plain-dict view of ``d`` suitable for comparisons.

  Dataclass instances are converted with ``dataclasses.asdict``. GitLab REST
  objects yield a copy of their attributes with some bookkeeping keys removed
  and the access level lists collapsed to a single value. Anything else is
  returned unchanged.
  """
  if dataclasses.is_dataclass(d):
    return dataclasses.asdict(d)
  if not isinstance(d, gitlab.base.RESTObject):
    return d

  attrs = dict(d.attributes)
  # skip some attributes that we don't want to print out
  for hidden in ('id', 'created_at', 'updated_at', 'owner', 'project_id'):
    attrs.pop(hidden, None)
  # special cases some asymmetric attributes: a non-empty list of levels is
  # replaced by the access level of its first entry, under the singular name
  for plural, singular in (
    ('merge_access_levels', 'merge_access_level'),
    ('push_access_levels', 'push_access_level'),
  ):
    levels = attrs.pop(plural, None)
    if levels:
      attrs[singular] = levels[0]['access_level']
  return attrs

def get_changes(src, dst):
  """Compute the entries of ``dst`` whose value differs from ``src``.

  Both arguments are first converted with ``asdict``. A key of ``dst`` ending
  in ``_attributes`` is compared against the ``src`` key without that suffix.
  Dict values are normalized before the comparison: the computed
  ``next_run_at`` key is ignored on the ``src`` side and ``name_regex_delete``
  is renamed to ``name_regex`` on the ``dst`` side.

  Neither ``src`` nor ``dst`` is modified, so the function can be called
  repeatedly on the same objects.

  Returns a dict mapping the differing ``dst`` keys to their (normalized)
  new value; it is empty when nothing has to change.
  """
  attributes_suffix = '_attributes'
  src_dict = asdict(src)
  dst_dict = asdict(dst)
  changes = {}
  for k, new_val in dst_dict.items():
    k_src = k[:-len(attributes_suffix)] if k.endswith(attributes_suffix) else k
    old_val = src_dict.get(k_src)
    if isinstance(old_val, dict):
      # compare a copy without the computed key: the key may be missing, and
      # the object holding the original dict must not be altered
      old_val = {key: val for key, val in old_val.items() if key != 'next_run_at'}
    if isinstance(new_val, dict) and 'name_regex_delete' in new_val:
      # convert to the compatibility name used by the read-side API, on a copy
      # so the caller's target definition is left untouched
      new_val = dict(new_val)
      new_val['name_regex'] = new_val.pop('name_regex_delete')
    if old_val != new_val:
      changes[k] = new_val
  return changes

# Result of split_sets_by_action(): each field is a list of items
SplittedSets = collections.namedtuple('SplittedSets', ['unchanged', 'add', 'drop', 'change'])

def split_sets_by_action(current, target, keyfunc):
  """Classify items by the action needed to turn ``current`` into ``target``.

  Items are matched across the two collections by ``keyfunc(item)``; when
  ``target`` holds several items with the same key the last one is used for
  the comparison.

  Returns a ``SplittedSets`` with:
    unchanged: items of ``current`` whose target counterpart has no changes
    add: items of ``target`` whose key is not in ``current``
    drop: items of ``current`` whose key is not in ``target``
    change: items of ``current`` that differ from their target counterpart
  """
  # build the lookups once instead of recomputing set operations per element
  target_dict = {keyfunc(b): b for b in target}
  current_keys = {keyfunc(b) for b in current}

  to_add = [b for b in target if keyfunc(b) not in current_keys]

  unchanged = []
  to_drop = []
  to_change = []
  for src in current:
    key = keyfunc(src)
    if key not in target_dict:
      to_drop.append(src)
    elif get_changes(src, target_dict[key]):
      to_change.append(src)
    else:
      unchanged.append(src)

  return SplittedSets(unchanged, to_add, to_drop, to_change)

def apply_changes(collection, obj, target, actiondesc):
  """Build the actions that bring ``obj`` in line with ``target``.

  Objects supporting in-place saves (``gitlab.mixins.SaveMixin``) get one
  action per changed attribute followed by a single save action. Other
  objects are deleted and then recreated through ``collection.create`` with
  the target definition.

  ``actiondesc`` is the prefix used for every action message.

  Returns a list of Action objects, empty when ``obj`` already matches.
  """
  changes = get_changes(obj, target)
  if not changes:
    # nothing differs: in particular, don't delete and recreate an object
    # that is already in the desired state
    return []

  if isinstance(obj, gitlab.mixins.SaveMixin):
    actions = [
      Action(
        msg=f'{actiondesc}: set {k} to {v}',
        # bind the current values as defaults since the action runs later
        func=lambda obj=obj, k=k, v=v: setattr(obj, k, v),
      )
      for k, v in changes.items()
    ]
    actions.append(
      Action(
        msg=f'{actiondesc}: save changes',
        func=obj.save,
      )
    )
    return actions

  return [
    Action(
      msg=f'{actiondesc}: clean',
      func=obj.delete,
    ),
    Action(
      msg=f'{actiondesc}: recreate',
      func=collection.create,
      params=(asdict(target),),
    ),
  ]

def collection_apply(collection, target, keyfunc, actiondesc):
  """Build the actions that make ``collection`` contain exactly ``target``.

  Items are matched by ``keyfunc(item)``. Missing items are created, items
  no longer wanted are deleted, and differing items are updated through
  ``apply_changes``. Each group is processed in key order.

  ``actiondesc`` is the prefix used for every action message.

  Returns a list of Action objects: additions, then drops, then edits.
  """
  # NOTE(review): list() is called without pagination arguments; presumably
  # only the first page is returned for large collections - confirm
  unchanged, to_add, to_drop, to_change = split_sets_by_action(collection.list(), target, keyfunc)
  # same key -> item mapping split_sets_by_action() diffs against (the last
  # definition wins), so the applied target is the one that was compared
  target_by_key = {keyfunc(t): t for t in target}
  ret = []
  for b in sorted(to_add, key=keyfunc):
    ret.append(
      Action(
        msg=f'{actiondesc}: add "{keyfunc(b)}"',
        func=collection.create,
        params=(asdict(b),),
    ))
  for b in sorted(to_drop, key=keyfunc):
    ret.append(
      Action(
        msg=f'{actiondesc}: drop "{keyfunc(b)}"',
        func=b.delete,
    ))
  for b in sorted(to_change, key=keyfunc):
    key = keyfunc(b)
    ret += apply_changes(collection, b, target_by_key[key], f'{actiondesc}: edit "{key}"')
  return ret

@dataclasses.dataclass(repr=False, eq=False, frozen=False)
class Action:
  """A described unit of work whose outcome is recorded on the instance."""
  # human readable description, also the base of the string form
  msg: str
  # callable run by invoke() as func(*params)
  func: callable = None
  params: tuple = ()
  # value returned by func, or the exception it raised
  ret: object = None

  def invoke(self):
    """Call ``func`` with ``params``, storing its result or its exception in ``ret``."""
    try:
      outcome = self.func(*self.params)
    except Exception as exc:
      outcome = exc
    self.ret = outcome

  def is_error(self):
    """Tell whether the recorded outcome is an exception."""
    return isinstance(self.ret, Exception)

  def __str__(self):
    return (self.msg + ' -> error') if self.is_error() else self.msg

  def __repr__(self):
    return f'{type(self).__name__}<{self}>'

class Error(Action):
  """An action that only reports a problem found while computing the rules."""

  def is_error(self):
    # always an error, whether or not invoke() has been called
    return True

  def invoke(self):
    # nothing to execute: record the message as the failure
    failure = Exception(self.msg)
    self.ret = failure

class Rule:
  """Turn one rule mapping into the actions needed on one project.

  Every ``rule_<key>`` method handles the entry ``<key>`` of the rule mapping
  and returns a list of Action objects.
  """

  def __init__(self, rulez):
    self.rulez = rulez
    self.rule = {}

  def prepare(self, project, rule):
    """Select the project and the rule mapping ``compute`` will work on."""
    self.project = project
    self.rule = rule

  def compute(self):
    """Run the handler of every rule entry and collect their actions.

    A list value is passed as positional arguments, a dict value as keyword
    arguments and any other value as a single argument. Entries without a
    matching ``rule_<key>`` method are skipped.
    """
    actions = []
    for k, v in self.rule.items():
      method = getattr(self, 'rule_' + k, None)
      if not method:
        continue
      if isinstance(v, list):
        actions += method(*v)
      elif isinstance(v, dict):
        actions += method(**v)
      else:
        actions += method(v)
    return actions

  def rule_ensure_branches(self, *items):
    """Plan the creation of the missing branches.

    Each item is a mapping with the branch ``name``, the ``base`` ref to
    create it from and an optional ``required`` flag. A branch that already
    exists is left alone. When the base branch is missing, an Error is
    reported only if ``required`` is true.
    """
    # NOTE(review): a single page of up to 100 branches is requested, so
    # projects with more branches are presumably seen only in part - confirm
    existing = {b.name for b in self.project.branches.list(per_page=100)}
    actions = []
    for i in items:
      name = i['name']
      base = i['base']
      required = i.get('required', False)

      if name in existing:
        continue

      if base in existing:
        actions.append(
          Action(
            msg='project {}: ensure_branch: branch {} from {}'.format(self.project.path_with_namespace, name, base),
            func=self.project.branches.create,
            params=({'branch': name, 'ref': base},),
          )
        )
      elif required:
        actions.append(
          Error(
            msg='project {}: ensure_branch: base ref "{}" for branch "{}" not found'.format(self.project.path_with_namespace, base, name),
          )
        )
    return actions

  def rule_protected_branches(self, *branches):
    """Plan the changes to the protected branches, matched by ``name``."""
    actions = collection_apply(self.project.protectedbranches, branches, lambda b: asdict(b)['name'], f'project {self.project.path_with_namespace}: protected_branches')
    return actions

  def rule_settings(self, **settings):
    """Plan the changes to the project attributes given as keywords."""
    # no collection is passed: this relies on the project being editable in
    # place - presumably it always is; confirm against the gitlab bindings
    actions = apply_changes(None, self.project, settings, f"project {self.project.path_with_namespace}: settings")
    return actions

  def rule_default_branch(self, branch):
    """Plan the change of the default branch, as a ``default_branch`` setting."""
    return self.rule_settings(default_branch=branch)

  def rule_schedules(self, *schedules):
    """Plan the changes to the pipeline schedules, matched by ``description``."""
    actions = collection_apply(self.project.pipelineschedules, schedules, lambda s: asdict(s)['description'], f'project {self.project.path_with_namespace}: schedules')
    return actions

class GitlabRulez:
  """Load rules from a YAML file and compute or apply them on GitLab projects."""

  def __init__(self, actually_apply=False):
    # when False the actions are computed and printed, but only the Error
    # ones are invoked (which just records their failure)
    self.actually_apply = actually_apply
    self.rulez = {}
    self.gl = None
    self.actions = []
    self.executed = []

  def connect(self, gitlab_instance, gitlab_server_url, gitlab_api_token):
    """Authenticate against GitLab.

    An explicit server URL (with its token) takes precedence; otherwise the
    connection parameters come from the configured instance with that name.
    """
    if gitlab_server_url:
      self.gl = gitlab.Gitlab(gitlab_server_url, private_token=gitlab_api_token)
    else:
      self.gl = gitlab.Gitlab.from_config(gitlab_instance)
    self.gl.auth()

  def _convert_constant_value(self, item, key):
    """Replace the symbolic access level in ``item[key]`` with its number.

    ``NO_ACCESS`` maps to 0 and any other name is looked up as an attribute
    of the ``gitlab`` module. Missing keys and values that are already
    integers are left untouched.
    """
    if key not in item:
      return
    symbol = item[key]
    if isinstance(symbol, int):
      # already numeric: there is no name to look up
      return
    if symbol == 'NO_ACCESS':
      item[key] = 0
    else:
      item[key] = getattr(gitlab, symbol)

  def load_rules(self, config_path):
    """Read the YAML rules file and resolve the symbolic access levels."""
    with open(config_path) as f:
      self.rulez = yaml.safe_load(f)

    for rule in self.rulez['rules']:
      for b in rule.get('protected_branches', []):
        self._convert_constant_value(b, 'merge_access_level')
        self._convert_constant_value(b, 'push_access_level')

  def _project_match_filter(self, project, filterglob):
    """Tell whether the project path matches the glob."""
    return fnmatch.fnmatch(project.path_with_namespace, filterglob)

  def _project_find_matching_rules(self, project):
    """Return the rules whose ``matches`` globs select the project.

    A rule is discarded as soon as one of its ``excludes`` globs matches.
    Both entries may be arbitrarily nested lists of globs.
    """
    rules = []
    for rule in self.rulez['rules']:
      excludes = (self._project_match_filter(project, i) for i in flatten(rule.get('excludes', [])))
      if any(excludes):
        continue
      matches = (self._project_match_filter(project, i) for i in flatten(rule['matches']))
      if any(matches):
        rules.append(rule)
    return rules

  def _project_compute_rule(self, project, rule):
    """Compute the actions of a single rule on the project."""
    r = Rule(self)
    r.prepare(project, rule)
    return r.compute()

  def _project_compute_rules(self, project, rules):
    """Print the project being processed and compute the actions of its rules."""
    print(project.id, project.path_with_namespace)
    if not rules:
      return []
    actions = []
    for rule in rules:
      actions += self._project_compute_rule(project, rule)
    return actions

  def apply_rules(self, filterglob=None, archived=False):
    """Compute the actions for the selected projects, print and maybe run them.

    Every action is printed. It is invoked and recorded in ``executed`` when
    ``actually_apply`` is set or when it is an Error.
    """
    projects = self.fetch_projects(filterglob, archived=archived)

    actions = []
    num_worker_threads = 10
    def compute_rules(project):
      rules = self._project_find_matching_rules(project)
      a = self._project_compute_rules(project, rules)
      # extend only once the whole project is computed, so that a retried
      # project does not contribute its actions twice
      actions.extend(a)
    thread_pool(num_worker_threads, compute_rules, projects, num_retries=2)
    self.actions = actions

    for action in self.actions:
      print(action)
      if self.actually_apply or isinstance(action, Error):
        action.invoke()
        self.executed.append(action)

  def summary(self):
    """Print the totals and return the process exit status.

    Returns 1 when an executed action failed, 10 when actions were computed
    but not applied, 0 otherwise.
    """
    ret = 0
    print('computed {} actions'.format(len(self.actions)))
    if self.actions:
      if self.actually_apply:
        print('applied {} actions'.format(len(self.executed)))
      else:
        ret = 10
    failed = [action for action in self.executed if action.is_error()]
    if failed:
      print('found {} errors'.format(len(failed)))
      for action in failed:
        print("{}: {}".format(action, action.ret))
      ret = 1
    return ret

  def fetch_projects(self, filterglob=None, **kwargs):
    """Return the list of projects selected by ``filterglob``.

    A filter without glob characters is taken as the path of a single
    project and fetched directly. Otherwise all the projects are listed,
    page by page on worker threads (so the resulting order may vary), and
    filtered by path. Extra keyword arguments are passed to the listing call.
    """
    if filterglob and not is_glob(filterglob):
      return [self.gl.projects.get(filterglob)]

    num_worker_threads = 10
    per_page = 100
    projects = []
    def fetch_page(page):
      p = self.gl.projects.list(per_page=per_page, page=page, **kwargs)
      projects.extend(p)
    # the page count is only valid for the page size used by fetch_page()
    total_pages = self.gl.projects.list(as_list=False, per_page=per_page, **kwargs).total_pages
    thread_pool(num_worker_threads, fetch_page, range(1, total_pages+1), num_retries=2)
    if filterglob:
      projects = [p for p in projects if self._project_match_filter(p, filterglob)]
    return projects

  def list_projects(self, filterglob=None, archived=False):
    """Print the id and path of every selected project."""
    projects = self.fetch_projects(filterglob, archived=archived, query_data={'simple': True})
    for p in projects:
      print(p.id, p.path_with_namespace)

if __name__ == '__main__':
  # Command line entry point: `list-projects` prints the selected projects,
  # `diff` shows what the rules would do and `apply` actually does it.
  parser = argparse.ArgumentParser(prog='gitlab-rulez')
  parser.add_argument(
    "--gitlab-instance",
    type=str,
    default="apertis",
    help="get connection parameters from this configured instance",
  )
  parser.add_argument("--gitlab-api-token", type=str, help="the GitLab API token")
  parser.add_argument("--gitlab-server-url", type=str, help="the GitLab instance URL")

  subparsers = parser.add_subparsers(dest='command', required=True)

  # --archived yields None when given and False otherwise
  archived_options = dict(action='store_const', const=None, default=False, help='include archived projects')

  parser_list = subparsers.add_parser('list-projects', help='list the project paths')
  parser_list.add_argument('--filter', type=str, help='only lists objects matching the specified path glob')
  parser_list.add_argument('--archived', **archived_options)

  # `diff` and `apply` take exactly the same arguments
  for command, summary in (
    ('diff', 'check what the rules would do'),
    ('apply', 'apply the changes to GitLab'),
  ):
    subparser = subparsers.add_parser(command, help=summary)
    subparser.add_argument('--filter', type=str, help='only act on objects matching the specified path glob')
    subparser.add_argument('--archived', **archived_options)
    subparser.add_argument('RULES', help='the YAML rules file')

  args = parser.parse_args()

  # only `apply` is allowed to modify anything
  r = GitlabRulez(args.command == 'apply')

  if args.command == 'list-projects':
    r.connect(args.gitlab_instance, args.gitlab_server_url, args.gitlab_api_token)
    r.list_projects(filterglob=args.filter, archived=args.archived)
    sys.exit(0)

  r.load_rules(args.RULES)
  r.connect(args.gitlab_instance, args.gitlab_server_url, args.gitlab_api_token)
  r.apply_rules(filterglob=args.filter, archived=args.archived)
  sys.exit(r.summary())
