#!/usr/bin/env python3

import requests
import yaml
import json
import os
import shutil
import logging
import time
import argparse
import subprocess
import hashlib
import daemon

# Root logger setup: timestamp, level name and message; INFO and above.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
)

class ClashUp:
    """Download a Clash subscription, merge local overrides and write config.yaml.

    Settings are read from ``~/.config/clash/clashup.json``; the generated
    Clash configuration is written to ``~/.config/clash/config.yaml`` and the
    previous one is kept as ``config.yaml.old``.
    """

    EXAMPLE_CONF = '''{
    "http_port": 7890,
    "socks5_port": 7891,
    "redir_port": 7892,
    "mixed_port": 7893,
    "allow_lan": true,
    "external_controller": "127.0.0.1:9090",
    "external_ui": "/usr/share/yacd",
    "subscribe_url": "",
    "is_subscribe_banned": false,
    "custom_rules": [],
    "custom_proxies": [],
    "custom_groups": [],
    "mmdb_file_url": "http://www.ideame.top/mmdb/Country.mmdb",
    "mmdb_version_url": "http://www.ideame.top/mmdb/version",
    "periodically_update": false
}
'''

    # Minimum age (seconds) of the time cache before a proxied update is due.
    UPDATE_INTERVAL_SECONDS = 86400
    # Timeout (seconds) for fetching the subscription.
    SUBSCRIBE_TIMEOUT = 5
    # Timeout (seconds) for the mmdb version / database requests. Without a
    # timeout ``requests`` waits indefinitely on an unresponsive server.
    MMDB_TIMEOUT = 30

    def __init__(self):
        """Resolve all file locations and create a proxy-agnostic HTTP session."""
        self.conf_file_path = os.path.expanduser('~/.config/clash/clashup.json')
        self.clash_conf_path = os.path.expanduser('~/.config/clash/config.yaml')
        self.clash_conf_path_old = os.path.expanduser('~/.config/clash/config.yaml.old')
        self.cache_file_path = os.path.expanduser('~/.cache/clashup')
        self.mmdb_version_file = os.path.expanduser('~/.cache/clashup-mmdb')
        self.mmdb_file_path = os.path.expanduser('~/.config/clash/Country.mmdb')
        self.session = requests.Session()
        # Ignore proxy settings from the environment; proxies are always
        # passed explicitly per request.
        self.session.trust_env = False

    @staticmethod
    def _write_atomic(path, data, binary=False):
        """Write *data* to *path* via a temporary file and ``os.replace``.

        The parent directory is created when missing. Readers never observe a
        partially written file, and a failed write leaves the old file intact.
        """
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        tmp_path = path + '.tmp'
        if binary:
            with open(tmp_path, 'wb') as f:
                f.write(data)
        else:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                f.write(data)
        os.replace(tmp_path, path)

    def _proxies(self, use_proxy):
        """Return the ``proxies`` mapping for a request.

        With *use_proxy* the local HTTP port from the configuration is used
        for both schemes; otherwise proxies are explicitly disabled.
        """
        if not use_proxy:
            return {'http': None, 'https': None}
        local_proxy = 'http://127.0.0.1:{}'.format(self.config['http_port'])
        return {'http': local_proxy, 'https': local_proxy}

    def load_conf(self) -> None:
        """Load clashup.json into ``self.config`` and hash its raw text.

        Raises:
            OSError: the config file did not exist; an example file has been
                written in its place and must be edited by the user.
            ValueError: the file is not valid JSON, is not a JSON object, or
                has an empty ``subscribe_url``.
        """
        if not os.path.isfile(self.conf_file_path):
            # Create the directory first so the example can be written on a
            # fresh system where ~/.config/clash does not exist yet.
            conf_dir = os.path.dirname(self.conf_file_path)
            if conf_dir:
                os.makedirs(conf_dir, exist_ok=True)
            with open(self.conf_file_path, 'w', encoding='utf-8') as f:
                f.write(self.EXAMPLE_CONF)
            raise OSError('plz edit ~/.config/clash/clashup.json')
        with open(self.conf_file_path, encoding='utf-8') as f:
            raw_config_text = f.read()
        raw_config = json.loads(raw_config_text)
        if not isinstance(raw_config, dict):
            raise ValueError('clashup.json must contain a JSON object')
        if not raw_config.get('subscribe_url'):
            raise ValueError('subscribe_url can not be empty')
        self.config = raw_config
        # The hash lets update_time_cache() notice edits to the config file.
        self.config_hash = hashlib.sha256(raw_config_text.encode()).hexdigest()

    def download(self, use_proxy=False):
        """Fetch the subscription and return it parsed as a mapping.

        Args:
            use_proxy: route the request through the local HTTP proxy port.

        Raises:
            requests.exceptions.RequestException: network or HTTP error.
            yaml.YAMLError: the response body is not valid YAML.
            ValueError: the YAML document is not a mapping.
        """
        res = self.session.get(
            self.config['subscribe_url'],
            timeout=self.SUBSCRIBE_TIMEOUT,
            proxies=self._proxies(use_proxy),
        )
        res.raise_for_status()
        raw_clash_conf = yaml.safe_load(res.text)
        if not isinstance(raw_clash_conf, dict):
            # e.g. an error page or a non-Clash subscription format
            raise ValueError('subscription is not a YAML mapping')
        return raw_clash_conf

    def _load_conf(self, config, local_config_key, config_key):
        """Copy ``self.config[local_config_key]`` to ``config[config_key]`` if set."""
        if local_config_key in self.config:
            config[config_key] = self.config[local_config_key]

    def parse_config(self, config):
        """Apply local overrides to the downloaded *config* and return it.

        Scalar settings are overwritten when present locally. Custom rules,
        proxies and groups are placed in front of the downloaded ones. A
        missing or null list on either side is treated as empty. *config* is
        modified in place.
        """
        self._load_conf(config, 'http_port', 'port')
        self._load_conf(config, 'socks5_port', 'socks-port')
        self._load_conf(config, 'redir_port', 'redir-port')
        self._load_conf(config, 'mixed_port', 'mixed-port')
        self._load_conf(config, 'allow_lan', 'allow-lan')
        self._load_conf(config, 'external_controller', 'external-controller')
        self._load_conf(config, 'external_ui', 'external-ui')
        list_keys = (
            ('custom_rules', 'rules'),
            ('custom_proxies', 'proxies'),
            ('custom_groups', 'proxy-groups'),
        )
        for local_key, clash_key in list_keys:
            # `or []` also covers keys that exist with a null value.
            custom = self.config.get(local_key) or []
            downloaded = config.get(clash_key) or []
            config[clash_key] = list(custom) + list(downloaded)
        return config

    def save(self, config) -> None:
        """Write *config* as YAML to config.yaml, keeping the old file as .old.

        The document is serialized before anything on disk is touched and the
        new file is swapped in atomically, so a failure never leaves Clash
        without a configuration file.
        """
        rendered = yaml.safe_dump(config, allow_unicode=True)
        if os.path.isfile(self.clash_conf_path):
            shutil.copy2(self.clash_conf_path, self.clash_conf_path_old)
        # UTF-8 is forced because allow_unicode=True emits raw non-ASCII text.
        self._write_atomic(self.clash_conf_path, rendered)

    def update(self, use_proxy) -> bool:
        """Download, merge and save the Clash configuration.

        Failures to fetch or parse the subscription are logged, not raised.

        Returns:
            True when a new configuration was written, False otherwise.
        """
        logging.info('Update Start')
        try:
            raw_clash_conf = self.download(use_proxy)
            parsed_clash_conf = self.parse_config(raw_clash_conf)
            self.save(parsed_clash_conf)
        except (requests.exceptions.RequestException, yaml.YAMLError, ValueError):
            logging.warning('Update Fail')
            return False
        logging.info('Update Finish')
        return True

    def update_mmdb(self) -> None:
        """Download Country.mmdb when the remote version differs from the cached one.

        Does nothing when the mmdb URLs are not configured. Network errors are
        logged, not raised.
        """
        version_url = self.config.get('mmdb_version_url')
        file_url = self.config.get('mmdb_file_url')
        if not (version_url and file_url):
            logging.info('mmdb urls not configured, pass mmdb update')
            return
        try:
            resp = self.session.get(
                version_url, timeout=self.MMDB_TIMEOUT, proxies=self._proxies(False))
            resp.raise_for_status()
            current_version = resp.text
            if os.path.isfile(self.mmdb_version_file):
                with open(self.mmdb_version_file, 'r', encoding='utf-8') as f:
                    if current_version == f.read():
                        logging.info('pass mmdb update')
                        return
            resp = self.session.get(
                file_url, timeout=self.MMDB_TIMEOUT, proxies=self._proxies(False))
            resp.raise_for_status()
            self._write_atomic(self.mmdb_file_path, resp.content, binary=True)
            # Record the version only after the database itself is in place.
            self._write_atomic(self.mmdb_version_file, current_version)
            logging.info('Update mmdb')
        except requests.exceptions.RequestException:
            logging.warning('Update mmdb failed')

    def _is_update_due(self) -> bool:
        """Return True when the time cache is missing, unreadable, stale or
        was written for a different config hash. Does not modify the cache."""
        try:
            with open(self.cache_file_path, 'r') as f:
                cached_hash, _, cached_time = f.read().partition('-')
            last_time = float(cached_time)
        except (OSError, ValueError):
            # Missing or corrupt cache: treat as "never updated".
            return True
        if cached_hash != self.config_hash:
            return True
        return time.time() - last_time > self.UPDATE_INTERVAL_SECONDS

    def update_time_cache(self) -> bool:
        """Refresh the time cache if an update is due.

        Returns:
            True when the cache was (re)written because it was missing,
            corrupt, older than 24 hours or belonged to a different config;
            False when it is still fresh.
        """
        if self._is_update_due():
            self._write_cache()
            return True
        return False

    def _write_cache(self) -> None:
        """Store ``<config hash>-<current unix time>`` in the cache file."""
        self._write_atomic(
            self.cache_file_path,
            '{}-{}'.format(self.config_hash, time.time()),
        )

    def run(self) -> None:
        """Parse command-line flags and perform the requested action.

        ``--update`` refreshes the config and the mmdb directly. ``--pre`` and
        ``--post`` are skipped when ``periodically_update`` is set.
        ``--pre`` updates without a proxy unless the subscription is marked as
        banned, then refreshes the mmdb in a daemonized process. ``--post``
        updates through the local proxy when the subscription is banned and
        an update is due, then restarts the ``clash`` user service.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument('--pre', action='store_true')
        parser.add_argument('--post', action='store_true')
        parser.add_argument('--update', action='store_true')
        args = parser.parse_args()
        self.load_conf()

        periodic = self.config.get('periodically_update', False)
        banned = self.config.get('is_subscribe_banned', False)

        if args.update:
            self.update(False)
            self.update_mmdb()
        elif args.pre and not periodic:
            if banned:
                logging.info('Subscribe is banned, pass this run')
            else:
                self.update(False)
            if self.config.get('mmdb_version_url') and self.config.get('mmdb_file_url'):
                with daemon.DaemonContext():
                    self.update_mmdb()
        elif args.post and not periodic:
            if not banned:
                logging.info('pass this run')
            elif not self._is_update_due():
                logging.info('config file updated in 24h, pass this run')
            elif self.update(True):
                # Stamp the cache only after a successful update so a failed
                # attempt is retried on the next run instead of being
                # suppressed for 24 hours.
                self._write_cache()
                subprocess.run(['systemctl', '--user', 'restart', 'clash'])
            else:
                logging.warning('Update through proxy failed, clash not restarted')

# Script entry point: build the updater and dispatch on the CLI flags.
if __name__ == '__main__':
    app = ClashUp()
    app.run()