#!/usr/bin/env python
# Copyright 2015 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

from skypy.skyserver import SkyServer
import argparse
import hashlib
import json
import logging
import os
import pipes
import platform
import re
import signal
import subprocess
import sys
import tempfile
import time
import urlparse

# Directory containing this script, its parent, and its grandparent.
# SRC_ROOT is used below to locate adb, the NDK gdb and the analyzer.
SKY_TOOLS_DIR = os.path.dirname(os.path.abspath(__file__))
SKY_ROOT = os.path.dirname(SKY_TOOLS_DIR)
SRC_ROOT = os.path.dirname(SKY_ROOT)

# Port gdbserver is told to listen on; the same port number is forwarded
# with `adb forward` and used by gdb for `target remote localhost:<port>`.
GDB_PORT = 8888
# Port the local SkyServer is created with by the `start` command.
SKY_SERVER_PORT = 9888
# Port forwarded with `adb forward` on every `start`
# (presumably the Dart Observatory port -- confirm).
OBSERVATORY_PORT = 8181
# URL loaded by `start` when no url_or_path argument is given.
DEFAULT_URL = "sky://domokit.github.io/home.dart"
# File name of the APK, looked up under <build_dir>/apks/.
APK_NAME = 'SkyDemo.apk'
ADB_PATH = os.path.join(SRC_ROOT,
    'third_party/android_tools/sdk/platform-tools/adb')
# Package name used for `am force-stop`, `pm path`, `pm list packages` and
# for matching the process in `ps` output.
ANDROID_PACKAGE = "org.domokit.sky.demo"
# On-device file holding the SHA1 of the APK last installed by `start`;
# compared against the local APK to decide whether a reinstall is needed.
SHA1_PATH = '/sdcard/%s/%s.sha1' % (ANDROID_PACKAGE, APK_NAME)

# JSON file used to persist state (pids, ports, build_dir) between
# invocations of this script.
PID_FILE_PATH = "/tmp/skydemo.pids"
# The only keys a Pids instance built from this set will accept.
PID_FILE_KEYS = frozenset([
    'remote_sky_server_port',
    'sky_server_pid',
    'sky_server_port',
    'sky_server_root',
    'build_dir',
    'sky_shell_pid',
    'remote_gdbserver_port',
])

# Host directory under which `gdb_attach` stores pulled device libraries,
# in a sub-directory named after the device serial number.
SYSTEM_LIBS_ROOT_PATH = '/tmp/device_libs'

# This 'strict dictionary' approach is useful for catching typos.
class Pids(object):
    """Dictionary-like store for pids/ports that only accepts known keys.

    Every keyed access asserts that the key is a member of `known_keys`, so
    a misspelled key fails loudly instead of silently creating a new entry.
    The contents can be loaded from and saved to a JSON file.
    """

    def __init__(self, known_keys, contents=None):
        """Create a store.

        Args:
            known_keys: container of the key names that may be used.
            contents: optional dict of initial values (used as-is, not
                copied); defaults to an empty dict.
        """
        self._known_keys = known_keys
        self._dict = contents if contents is not None else {}

    def _check_key(self, key):
        """Assert that `key` is one of the known keys."""
        assert key in self._known_keys, '%s not in known_keys' % key

    def __len__(self):
        return len(self._dict)

    def get(self, key, default=None):
        """Return the value for `key`, or `default` when it is not set."""
        self._check_key(key)
        return self._dict.get(key, default)

    def __getitem__(self, key):
        self._check_key(key)
        return self._dict[key]

    def __setitem__(self, key, value):
        self._check_key(key)
        self._dict[key] = value

    def __delitem__(self, key):
        self._check_key(key)
        del self._dict[key]

    def __iter__(self):
        return iter(self._dict)

    def __contains__(self, key):
        self._check_key(key)
        return key in self._dict

    def clear(self):
        """Forget every stored value."""
        self._dict = {}

    def pop(self, key, default=None):
        """Remove `key` and return its value, or `default` when not set."""
        self._check_key(key)
        return self._dict.pop(key, default)

    @classmethod
    def read_from(cls, path, known_keys):
        """Build a Pids from the JSON file at `path`.

        A missing file silently yields an empty store.  A file that exists
        but cannot be read or parsed, or whose top-level JSON value is not
        an object, is reported with a warning and also yields an empty
        store.
        """
        contents = {}
        try:
            with open(path, 'r') as pid_file:
                contents = json.load(pid_file)
        except (IOError, OSError, ValueError):
            # json reports malformed input as ValueError.  Only warn when
            # the file is actually there; "no pid file yet" is normal.
            if os.path.exists(path):
                logging.warning('Failed to read pid file: %s', path)
        if not isinstance(contents, dict):
            # Valid JSON but not an object (e.g. a list): the rest of the
            # class relies on dict semantics, so start from scratch.
            logging.warning('Failed to read pid file: %s', path)
            contents = {}
        return cls(known_keys, contents)

    def write_to(self, path):
        """Write the contents to `path` as JSON.

        Best effort: a failure is logged as a warning, never raised.
        """
        try:
            with open(path, 'w') as pid_file:
                json.dump(self._dict, pid_file, indent=2, sort_keys=True)
        except (IOError, OSError, TypeError, ValueError):
            # TypeError/ValueError cover values json cannot serialize.
            logging.warning('Failed to write pid file: %s', path)


def _convert_to_sky_url(url):
    parts = urlparse.urlsplit(url)
    parts = parts._replace(scheme='sky')
    return parts.geturl()


# A free function for possible future sharing with a 'load' command.
def _url_from_args(args, pids):
    if urlparse.urlparse(args.url_or_path).scheme:
        return args.url_or_path
    # The load happens on the remote device, use the remote port.
    remote_sky_server_port = pids.get('remote_sky_server_port',
        pids['sky_server_port'])
    url = SkyServer.url_for_path(remote_sky_server_port,
        pids['sky_server_root'], args.url_or_path)
    return _convert_to_sky_url(url)


def dev_packages_root(build_dir):
    """Return <build_dir>/gen/dart-pkg/packages."""
    relative_parts = ('gen', 'dart-pkg', 'packages')
    return os.path.join(build_dir, *relative_parts)


def ensure_assets_are_downloaded(build_dir):
    """Run the icon download script when the icon directory is missing.

    Looks for gen/dart-pkg/sky/lib/assets/material-design-icons under
    `build_dir`; when it is not a directory, runs the
    `download_material_design_icons` script that sits in the same lib
    directory and raises if that script fails.
    """
    lib_dir = os.path.join(build_dir, 'gen', 'dart-pkg', 'sky', 'lib')
    icons_dir = os.path.join(lib_dir, 'assets', 'material-design-icons')
    if os.path.isdir(icons_dir):
        return

    logging.info('NOTE: sky/assets/material-design-icons missing, '
                 'Running `download_material_design_icons` for you.')
    downloader = os.path.join(lib_dir, 'download_material_design_icons')
    subprocess.check_call([downloader])


class SetBuildDir(object):
    """`set_build_dir` sub-command: record a build directory in the pids."""

    def add_subparser(self, subparsers):
        """Register the `set_build_dir` sub-command and its argument."""
        parser = subparsers.add_parser(
            'set_build_dir',
            help='force the build_dir to a particular value without starting Sky')
        parser.add_argument('build_dir', type=str)
        parser.set_defaults(func=self.run)

    def run(self, args, pids):
        """Store the absolute form of `args.build_dir` under 'build_dir'."""
        absolute_build_dir = os.path.abspath(args.build_dir)
        pids['build_dir'] = absolute_build_dir


class StartSky(object):
    """`start` sub-command: serve the app locally and launch it on device.

    Starts a local SkyServer, makes sure the APK is on the device
    (reinstalling only when its SHA1 changed), sets up adb port
    forwarding, launches the app with the requested URL and, with --gdb,
    attaches gdbserver to the running app.
    """

    def add_subparser(self, subparsers):
        """Register the `start` sub-command and its arguments."""
        start_parser = subparsers.add_parser('start',
            help='launch SkyShell.apk on the device')
        start_parser.add_argument('build_dir', type=str)
        start_parser.add_argument('--gdb', action="store_true")
        start_parser.add_argument('url_or_path', nargs='?', type=str,
            default=DEFAULT_URL)
        start_parser.add_argument('--no_install', action="store_false",
            default=True, dest="install",
            help="Don't install SkyDemo.apk before starting")
        start_parser.set_defaults(func=self.run)

    def _server_root_for_url(self, url_or_path):
        """Return the directory the sky server should serve from.

        SRC_ROOT is used when the path lies inside it; otherwise the
        path's own directory is used and a warning is logged.
        """
        path = os.path.abspath(url_or_path)
        # Compare whole path components.  os.path.commonprefix works
        # character by character and would treat '/src-old/x' as being
        # inside '/src'.  join(..., '') appends exactly one separator.
        if path == SRC_ROOT or path.startswith(os.path.join(SRC_ROOT, '')):
            return SRC_ROOT
        server_root = os.path.dirname(path)
        logging.warning(
            '%s is outside of mojo root, using %s as server root',
            path, server_root)
        return server_root

    def _sky_server_for_args(self, args, packages_root):
        """Create (but do not start) the SkyServer for `args`."""
        server_root = self._server_root_for_url(args.url_or_path)
        return SkyServer(SKY_SERVER_PORT, server_root, packages_root)

    @staticmethod
    def _pid_from_ps_output(ps_output, package):
        """Return the pid column of the `ps` row whose last field is `package`.

        Returns None when no row matches.  The pid is returned as the
        string found in the second column.
        """
        for line in ps_output.splitlines():
            fields = line.split()
            # Need at least two columns to have a pid to return.
            if len(fields) > 1 and fields[-1] == package:
                return fields[1]
        return None

    def _find_remote_pid_for_package(self, package):
        """Return the on-device pid (a string) of `package`, or None."""
        ps_output = subprocess.check_output([ADB_PATH, 'shell', 'ps'])
        return self._pid_from_ps_output(ps_output, package)

    def _find_install_location_for_package(self, package):
        """Return the on-device path `pm path` reports for `package`."""
        pm_command = [ADB_PATH, 'shell', 'pm', 'path', package]
        pm_output = subprocess.check_output(pm_command)
        # e.g. package:/data/app/org.chromium.mojo.shell-1/base.apk
        # Strip the trailing line ending so the result is a clean path.
        return pm_output.split(':')[-1].strip()

    def _sha1_of_file(self, path):
        """Return the hex SHA1 digest of the file at `path`."""
        with open(path, 'rb') as apk_file:
            return hashlib.sha1(apk_file.read()).hexdigest()

    def _device_has_apk(self, source_sha1):
        """Return True when the device already holds the APK `source_sha1`.

        That requires the SHA1 recorded on the device to match and the
        package to still be listed by the package manager.
        """
        try:
            dest_sha1 = subprocess.check_output(
                [ADB_PATH, 'shell', 'cat', SHA1_PATH]).strip()
        except subprocess.CalledProcessError:
            # No readable SHA1 record on the device: treat as "not
            # installed" and let the install path report real adb errors.
            return False
        if source_sha1 != dest_sha1:
            return False
        # Make sure that the APK didn't get uninstalled somehow
        listing = subprocess.check_output([
            ADB_PATH, 'shell', 'pm', 'list', 'packages', ANDROID_PACKAGE
        ])
        return bool(listing.strip())

    def _install_apk(self, apk_path, source_sha1):
        """Install `apk_path` on the device and record its SHA1 there."""
        # Slow path, need to upload a new APK to the device
        # -r to replace an existing apk, -d to allow version downgrade.
        subprocess.check_call([ADB_PATH, 'install', '-r', '-d', apk_path])
        # record the SHA1 of the APK we just pushed
        with tempfile.NamedTemporaryFile() as fp:
            fp.write(source_sha1)
            # Make sure the digest is on disk before adb reads the file.
            fp.flush()
            subprocess.check_call([ADB_PATH, 'push', fp.name, SHA1_PATH])

    def _wait_for_remote_pid(self, package):
        """Poll the device for the pid of `package`.

        Retries with a 5 second pause between attempts and gives up after
        the initial attempt plus four retries.  Returns the pid or None.
        """
        # TODO(eseidel): am start -W does not seem to work?
        pid_tries = 0
        while True:
            pid = self._find_remote_pid_for_package(package)
            if pid or pid_tries > 3:
                return pid
            logging.debug('No pid for %s yet, waiting', package)
            time.sleep(5)
            pid_tries += 1

    def _attach_gdbserver(self, pids):
        """Attach gdbserver to the running app and forward its port."""
        pid = self._wait_for_remote_pid(ANDROID_PACKAGE)
        if not pid:
            logging.error('Failed to find pid on device!')
            return

        pids['sky_shell_pid'] = pid

        # We push our own copy of gdbserver with the package since
        # the default gdbserver is a different version from our gdb.
        package_path = \
            self._find_install_location_for_package(ANDROID_PACKAGE)
        gdb_server_path = os.path.join(
            os.path.dirname(package_path), 'lib/arm/gdbserver')
        gdbserver_cmd = [
            ADB_PATH, 'shell',
            gdb_server_path, '--attach',
            ':%d' % GDB_PORT,
            str(pid)
        ]
        print(' '.join(map(pipes.quote, gdbserver_cmd)))
        # Left running in the background; its pid is not recorded.
        subprocess.Popen(gdbserver_cmd)

        port_string = 'tcp:%d' % GDB_PORT
        subprocess.check_call([
            ADB_PATH, 'forward', port_string, port_string
        ])
        pids['remote_gdbserver_port'] = GDB_PORT

    def run(self, args, pids):
        """Run the `start` command.

        Returns 2 when the APK is missing from the build directory and
        None otherwise.  Failing adb commands raise
        subprocess.CalledProcessError.
        """
        apk_path = os.path.join(args.build_dir, 'apks', APK_NAME)
        if not os.path.exists(apk_path):
            print("'%s' does not exist?" % apk_path)
            return 2

        ensure_assets_are_downloaded(args.build_dir)

        packages_root = dev_packages_root(args.build_dir)
        sky_server = self._sky_server_for_args(args, packages_root)
        pids['sky_server_pid'] = sky_server.start()
        pids['sky_server_port'] = sky_server.port
        pids['sky_server_root'] = sky_server.root

        pids['build_dir'] = os.path.abspath(args.build_dir)

        source_sha1 = None
        if args.install:
            # We might need to install a new APK, so check SHA1
            source_sha1 = self._sha1_of_file(apk_path)
            use_existing_apk = self._device_has_apk(source_sha1)
        else:
            # User is telling us not to bother installing an APK
            use_existing_apk = True

        if use_existing_apk:
            # APK is already on the device, we only need to stop it
            subprocess.check_call([
                ADB_PATH, 'shell', 'am', 'force-stop', ANDROID_PACKAGE
            ])
        else:
            self._install_apk(apk_path, source_sha1)

        # Set up port forwarding for observatory
        port_string = 'tcp:%s' % OBSERVATORY_PORT
        subprocess.check_call([
            ADB_PATH, 'forward', port_string, port_string
        ])

        # Let the device reach the sky server running on this machine.
        port_string = 'tcp:%s' % sky_server.port
        subprocess.check_call([
            ADB_PATH, 'reverse', port_string, port_string
        ])
        pids['remote_sky_server_port'] = sky_server.port

        subprocess.check_call([ADB_PATH, 'shell',
            'am', 'start',
            '-a', 'android.intent.action.VIEW',
            '-d', _url_from_args(args, pids)])

        if args.gdb:
            self._attach_gdbserver(pids)


class GDBAttach(object):
    """`gdb_attach` sub-command: run gdb against gdbserver on the device.

    Relies on state recorded in the pid file by an earlier `start --gdb`
    ('build_dir' and 'sky_shell_pid').
    """

    def add_subparser(self, subparsers):
        """Register the `gdb_attach` sub-command."""
        start_parser = subparsers.add_parser('gdb_attach',
            help='attach to gdbserver running on device')
        start_parser.set_defaults(func=self.run)

    def _pull_system_libraries(self, pids, system_libs_root):
        """Cache the device's libraries locally and list their directories.

        Runs android_library_cacher.py for the recorded app pid, then
        returns the directories found (depth 1 to 4) under
        `system_libs_root`.
        """
        # Pull down the system libraries this pid has already mapped in.
        # TODO(eseidel): This does not handle dynamic loads.
        library_cacher_path = os.path.join(
            SKY_TOOLS_DIR, 'android_library_cacher.py')
        status = subprocess.call([
            library_cacher_path, system_libs_root, str(pids['sky_shell_pid'])
        ])
        if status != 0:
            # Best effort, as before, but no longer silent.
            logging.warning('%s exited with status %s',
                            library_cacher_path, status)

        # TODO(eseidel): adb_gdb does, this, unclear why solib-absolute-prefix
        # doesn't make this explicit listing not necessary?
        find_output = subprocess.check_output([
            'find', system_libs_root,
            '-mindepth', '1',
            '-maxdepth', '4',
            '-type', 'd',
        ])
        # Skip empty entries so that empty output gives [] and not [''].
        return [path for path in find_output.strip().split('\n') if path]

    def run(self, args, pids):
        """Exec gdb, connected to the gdbserver port on localhost.

        Returns 2 when the pid file lacks the state written by
        `start --gdb`.  On success this does not return: the process is
        replaced by gdb.
        """
        build_dir = pids.get('build_dir')
        if not build_dir or not pids.get('sky_shell_pid'):
            logging.fatal("pids file missing build_dir or sky_shell_pid. "
                          "Try 'start --gdb' first.")
            return 2

        eval_commands = [
            'directory %s' % SRC_ROOT,
            # TODO(eseidel): What file do I point it at?  The apk?
            #'file %s' % self.paths.mojo_shell_path,
            'target remote localhost:%s' % GDB_PORT,
        ]

        # TODO(iansf): Fix undefined behavior when you have more than one device attached.
        device_id = subprocess.check_output([ADB_PATH, 'get-serialno']).strip()
        device_libs_path = os.path.join(SYSTEM_LIBS_ROOT_PATH, device_id)

        system_lib_dirs = self._pull_system_libraries(pids, device_libs_path)
        eval_commands.append('set solib-absolute-prefix %s' % device_libs_path)

        # Device library directories first, then the build directory.
        symbol_search_paths = system_lib_dirs + [build_dir]

        # TODO(eseidel): We need to look up the toolchain somehow?
        if platform.system() == 'Darwin':
            host_dir = 'darwin-x86_64'
        else:
            host_dir = 'linux-x86_64'
        gdb_path = os.path.join(SRC_ROOT, 'third_party/android_tools/ndk/'
            'toolchains/arm-linux-androideabi-4.9/prebuilt', host_dir,
            'bin/arm-linux-androideabi-gdb')

        # Set solib-search-path after letting android modify symbol_search_paths
        eval_commands.append(
            'set solib-search-path %s' % ':'.join(symbol_search_paths))

        exec_command = [gdb_path]
        for command in eval_commands:
            exec_command += ['--eval-command', command]

        print(" ".join(exec_command))

        # Write out our pid file before we exec ourselves.
        pids.write_to(PID_FILE_PATH)

        # Exec gdb directly to avoid python intercepting symbols, etc.
        os.execv(exec_command[0], exec_command)



class StopSky(object):
    """`stop` sub-command: undo what `start` set up.

    Kills the local sky server, removes the adb port mappings recorded in
    the pid file, force-stops the app on the device and clears the pids.
    """

    def add_subparser(self, subparsers):
        """Register the `stop` sub-command."""
        stop_parser = subparsers.add_parser('stop',
            help=('kill all running SkyShell.apk processes'))
        stop_parser.set_defaults(func=self.run)

    def _kill_if_exists(self, pids, key, name):
        """Remove `key` from `pids` and send SIGTERM to the pid it held.

        Does nothing but log when no pid was stored or when the process
        is already gone.  `name` is only used in log messages.
        """
        pid = pids.pop(key, None)
        if not pid:
            logging.info('No pid for %s, nothing to do.' % name)
            return
        logging.info('Killing %s (%d).' % (name, pid))
        try:
            os.kill(pid, signal.SIGTERM)
        except OSError:
            logging.info('%s (%d) already gone.' % (name, pid))

    def _adb_reverse_remove(self, port):
        """Remove the `adb reverse` mapping for tcp:`port` (best effort)."""
        port_string = 'tcp:%s' % port
        subprocess.call([ADB_PATH, 'reverse', '--remove', port_string])

    def _adb_forward_remove(self, port):
        """Remove the `adb forward` mapping for tcp:`port` (best effort)."""
        port_string = 'tcp:%s' % port
        subprocess.call([ADB_PATH, 'forward', '--remove', port_string])

    def run(self, args, pids):
        """Run the `stop` command; adb failures are ignored."""
        self._kill_if_exists(pids, 'sky_server_pid', 'sky_server')

        if 'remote_sky_server_port' in pids:
            self._adb_reverse_remove(pids['remote_sky_server_port'])

        if 'remote_gdbserver_port' in pids:
            # Only the port forward can be undone here.  The host-side
            # `adb shell gdbserver` process launched by `start --gdb` is
            # never recorded in the pid file (PID_FILE_KEYS has no key for
            # it), so there is no pid to kill.  The previous attempt to
            # kill it omitted the `pids` argument and raised TypeError,
            # aborting the rest of `stop`.
            self._adb_forward_remove(pids['remote_gdbserver_port'])

        subprocess.call([
            ADB_PATH, 'shell', 'am', 'force-stop', ANDROID_PACKAGE])

        pids.clear()


class Analyze(object):
    """`analyze` sub-command: run sky/tools/skyanalyzer on an app."""

    def add_subparser(self, subparsers):
        """Register the `analyze` sub-command and its arguments."""
        analyze_parser = subparsers.add_parser('analyze',
            help=('run the dart analyzer with sky url mappings'))
        analyze_parser.add_argument('app_path', type=str)
        analyze_parser.add_argument('--congratulate', action="store_true")
        analyze_parser.set_defaults(func=self.run)

    def run(self, args, pids):
        """Run the analyzer and echo its output to stderr.

        Returns 2 when no build_dir is recorded in `pids`; otherwise the
        analyzer's exit status (0 on success).
        """
        build_dir = pids.get('build_dir')
        if not build_dir:
            logging.fatal("pids file missing build_dir. Try 'start' first.")
            return 2
        analyzer_path = os.path.join(SRC_ROOT, 'sky/tools/skyanalyzer')
        analyzer_args = [
            analyzer_path,
            build_dir,
            args.app_path
        ]
        if args.congratulate:
            analyzer_args.append('--congratulate')
        try:
            output = subprocess.check_output(
                analyzer_args, stderr=subprocess.STDOUT)
            result = 0
        except subprocess.CalledProcessError as e:
            # A non-zero exit still carries the analyzer's report.
            output = e.output
            result = e.returncode
        # splitlines() does not produce a trailing empty entry, so nothing
        # has to be popped and a final line without a newline is kept.
        for line in output.splitlines():
            sys.stderr.write(line + '\n')
        return result


class StartTracing(object):
    """`start_tracing` sub-command: broadcast the TRACING_START intent."""

    def add_subparser(self, subparsers):
        """Register the `start_tracing` sub-command."""
        parser = subparsers.add_parser(
            'start_tracing',
            help='start tracing a running sky instance')
        parser.set_defaults(func=self.run)

    def run(self, args, pids):
        """Send the broadcast through adb; raises if adb fails."""
        broadcast_command = [
            ADB_PATH, 'shell',
            'am', 'broadcast',
            '-a', 'org.domokit.sky.demo.TRACING_START',
        ]
        subprocess.check_output(broadcast_command)


# Markers searched for in `adb logcat` output after tracing is stopped.
TRACE_COMPLETE_REGEXP = re.compile('Trace complete')
TRACE_FILE_REGEXP = re.compile(r'Saving trace to (?P<path>\S+)')

class StopTracing(object):
    """`stop_tracing` sub-command: stop tracing and fetch the trace file."""

    def add_subparser(self, subparsers):
        """Register the `stop_tracing` sub-command."""
        stop_tracing_parser = subparsers.add_parser('stop_tracing',
            help=('stop tracing a running sky instance'))
        stop_tracing_parser.set_defaults(func=self.run)

    def run(self, args, pids):
        """Stop tracing, wait for completion, then pull the trace file.

        Clears logcat, broadcasts TRACING_STOP and polls logcat every 0.2
        seconds until the completion marker shows up (there is no
        timeout).  The trace file named in the log is pulled and then
        deleted from the device.

        Returns 1 when tracing completed but no trace path was logged,
        None otherwise.
        """
        subprocess.check_output([ADB_PATH, 'logcat', '-c'])
        subprocess.check_output([ADB_PATH, 'shell',
            'am', 'broadcast',
            '-a', 'org.domokit.sky.demo.TRACING_STOP'])
        device_path = None
        is_complete = False
        while not is_complete:
            time.sleep(0.2)
            log = subprocess.check_output([ADB_PATH, 'logcat', '-d'])
            if device_path is None:
                result = TRACE_FILE_REGEXP.search(log)
                if result:
                    device_path = result.group('path')
            is_complete = TRACE_COMPLETE_REGEXP.search(log) is not None

        if not device_path:
            # Check before using the path: os.path.basename(None) raises.
            logging.error('Trace completed but no trace file path was logged.')
            return 1

        print('Downloading trace %s ...' % os.path.basename(device_path))
        subprocess.check_output([ADB_PATH, 'pull', device_path])
        subprocess.check_output([ADB_PATH, 'shell', 'rm', device_path])


class SkyShellRunner(object):
    """Command-line entry point that dispatches to the sub-commands."""

    def main(self):
        """Parse the command line, run the chosen command and exit.

        The pid file is loaded before the command runs and written back
        afterwards.  The process exits with the command's return value.
        """
        logging.basicConfig(level=logging.WARNING)

        parser = argparse.ArgumentParser(description='Sky Shell Runner')
        subparsers = parser.add_subparsers(help='sub-command help')

        commands = [
            SetBuildDir(),
            StartSky(),
            StopSky(),
            Analyze(),
            GDBAttach(),
            StartTracing(),
            StopTracing(),
        ]

        for command in commands:
            command.add_subparser(subparsers)

        args = parser.parse_args()
        pids = Pids.read_from(PID_FILE_PATH, PID_FILE_KEYS)
        try:
            exit_code = args.func(args, pids)
        finally:
            # Persist whatever the command recorded even when it raised,
            # so that e.g. a sky_server started just before a failing adb
            # call can still be found and killed by a later `stop`.
            pids.write_to(PID_FILE_PATH)
        sys.exit(exit_code)


if __name__ == '__main__':
    SkyShellRunner().main()
