#!/usr/bin/env python
# Copyright 2015 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import argparse
import json
import logging
import os
import re
import signal
import socket
import subprocess
import sys
import urlparse
import time

# TODO(eseidel): This should be BIN_DIR.
# Directory containing this script, with symlinks resolved.
LIB_DIR = os.path.realpath(os.path.dirname(os.path.abspath(__file__)))
# Parent directory of LIB_DIR; the 'apks' directory is looked up under it.
SKY_PACKAGE_ROOT = os.path.realpath(os.path.join(LIB_DIR, os.pardir))

# Port used for the local file server and for the adb reverse mapping.
SKY_SERVER_PORT = 9888
APK_NAME = 'SkyDemo.apk'
ANDROID_PACKAGE = 'org.domokit.sky.demo'
# FIXME: This assumes adb is in $PATH, we could look for ANDROID_HOME, etc?
ADB_PATH = 'adb'
# FIXME: Do we need to look in $DART_SDK?
DART_PATH = 'dart'

# State persisted between invocations, and the only keys it may contain.
PID_FILE_PATH = '/tmp/sky_tool.pids'
PID_FILE_KEYS = frozenset((
    'build_dir',
    'remote_sky_server_port',
    'sky_server_pid',
    'sky_server_port',
    'sky_server_root',
))


def _port_in_use(port):
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    return sock.connect_ex(('localhost', port)) == 0


# We need something to serve dart files, python's httpserver is sufficient.
def _start_http_server(port, root):
    """Launch `python -m SimpleHTTPServer <port>` with cwd `root`.

    Returns the pid of the spawned process; the process is not waited on.
    """
    server = subprocess.Popen(
        ['python', '-m', 'SimpleHTTPServer', str(port)], cwd=root)
    return server.pid


# This 'strict dictionary' approach is useful for catching typos.
class Pids(object):
    """Dictionary-like store that only accepts a fixed set of keys.

    Every keyed accessor asserts that the key is in `known_keys`, so a
    misspelled key fails loudly (AssertionError) instead of silently
    creating or missing an entry.  The checks are plain `assert`s and are
    therefore skipped when Python runs with -O.
    """

    def __init__(self, known_keys, contents=None):
        """Wrap `contents` (a dict, or None for empty) guarded by `known_keys`."""
        self._known_keys = known_keys
        self._dict = contents if contents is not None else {}

    def _check_key(self, key):
        """Assert that `key` is one of the known keys."""
        assert key in self._known_keys, '%s not in known_keys' % key

    def __len__(self):
        return len(self._dict)

    def get(self, key, default=None):
        """Return the value for `key`, or `default` if it is not set."""
        self._check_key(key)
        return self._dict.get(key, default)

    def __getitem__(self, key):
        self._check_key(key)
        return self._dict[key]

    def __setitem__(self, key, value):
        self._check_key(key)
        self._dict[key] = value

    def __delitem__(self, key):
        self._check_key(key)
        del self._dict[key]

    def __iter__(self):
        return iter(self._dict)

    def __contains__(self, key):
        self._check_key(key)
        return key in self._dict

    def clear(self):
        """Drop every stored entry."""
        self._dict = {}

    def pop(self, key, default=None):
        """Remove `key` and return its value, or `default` if it is not set."""
        self._check_key(key)
        return self._dict.pop(key, default)

    @classmethod
    def read_from(cls, path, known_keys):
        """Build a Pids from the JSON file at `path`.

        A missing file yields an empty store silently.  A file that exists
        but cannot be read, is not valid JSON, or does not hold a JSON
        object yields an empty store and logs a warning.
        """
        contents = {}
        try:
            with open(path, 'r') as pid_file:
                contents = json.load(pid_file)
        except (IOError, OSError, ValueError):
            # ValueError covers malformed JSON.  Only warn when the file is
            # actually there; a missing pid file is the normal first-run case.
            if os.path.exists(path):
                logging.warning('Failed to read pid file: %s', path)
        if not isinstance(contents, dict):
            # Valid JSON of the wrong shape (list, number, ...) would break
            # every dict operation later on, so treat it as unreadable.
            logging.warning('Failed to read pid file: %s', path)
            contents = {}
        return cls(known_keys, contents)

    def write_to(self, path):
        """Write the stored entries to `path` as JSON; log a warning on failure."""
        try:
            with open(path, 'w') as pid_file:
                json.dump(self._dict, pid_file, indent=2, sort_keys=True)
        except (IOError, OSError, TypeError, ValueError):
            # TypeError/ValueError cover values json cannot serialize.
            logging.warning('Failed to write pid file: %s', path)


def _url_for_path(port, root, path):
    relative_path = os.path.relpath(path, root)
    return 'sky://localhost:%s/%s' % (port, relative_path)


class StartSky(object):
    """The `start` sub-command: serve a project over HTTP and open it on the device."""

    def add_subparser(self, subparsers):
        """Register the `start` sub-command and its arguments on `subparsers`."""
        start_parser = subparsers.add_parser('start',
            help='launch %s on the device' % APK_NAME)
        start_parser.add_argument('--install', action='store_true')
        start_parser.add_argument('project_or_path', nargs='?', type=str,
            default='main.dart')
        start_parser.set_defaults(func=self.run)

    def _is_package_installed(self, package_name):
        """Return True if `adb shell pm path <package_name>` prints anything."""
        pm_path_cmd = [ADB_PATH, 'shell', 'pm', 'path', package_name]
        # Truthiness works whether check_output returns str or bytes;
        # comparing against '' is always unequal for bytes.
        return bool(subprocess.check_output(pm_path_cmd).strip())

    def run(self, args, pids):
        """Stop any previous session, then start the server and launch the app.

        Args:
            args: parsed arguments; reads `project_or_path` and `install`
                (`install` is set to True here if the package is missing).
            pids: Pids store, updated with the server pid, root and ports.

        Returns:
            2 when the project, packages directory or apk is missing;
            None on success.

        Raises:
            subprocess.CalledProcessError: if a dart or adb command fails.
        """
        StopSky().run(args, pids)

        project_or_path = os.path.abspath(args.project_or_path)

        if os.path.isdir(project_or_path):
            sky_server_root = project_or_path
            main_dart = os.path.join(project_or_path, 'main.dart')
            missing_msg = "Missing main.dart in project: %s" % sky_server_root
        else:
            # FIXME: This assumes the path is at the root of the project!
            # Instead we should walk up looking for a pubspec.yaml
            sky_server_root = os.path.dirname(project_or_path)
            main_dart = project_or_path
            missing_msg = "%s does not exist." % main_dart

        if not os.path.isfile(main_dart):
            print(missing_msg)
            return 2

        package_root = os.path.join(sky_server_root, 'packages')
        if not os.path.isdir(package_root):
            print("%s is not a valid packages path." % package_root)
            return 2

        # NOTE(review): the script path below is relative to the current
        # working directory, not to sky_server_root -- confirm that is intended.
        subprocess.check_call([
            DART_PATH,
            '--package-root=%s' % package_root,
            'packages/mojom/generate.dart'
        ])

        if not self._is_package_installed(ANDROID_PACKAGE):
            print('%s is not installed, installing.' % APK_NAME)
            args.install = True

        if args.install:
            apk_path = os.path.join(SKY_PACKAGE_ROOT, 'apks', APK_NAME)
            if not os.path.exists(apk_path):
                print("'%s' does not exist?" % apk_path)
                return 2

            subprocess.check_call([ADB_PATH, 'install', '-r', apk_path])

        sky_server_port = SKY_SERVER_PORT
        pids['sky_server_port'] = sky_server_port
        if _port_in_use(sky_server_port):
            # Whatever already listens on the port is reused as-is; no pid is
            # recorded, so `stop` will not kill it.
            logging.warning(
                'Port %s already in use. Not starting server for %s',
                sky_server_port, sky_server_root)
        else:
            sky_server_pid = _start_http_server(sky_server_port, sky_server_root)
            pids['sky_server_pid'] = sky_server_pid
            pids['sky_server_root'] = sky_server_root

        # Map the same port number on the device back to the host server.
        port_string = 'tcp:%s' % sky_server_port
        subprocess.check_call([
            ADB_PATH, 'reverse', port_string, port_string
        ])
        pids['remote_sky_server_port'] = sky_server_port

        # The load happens on the remote device, use the remote port.
        sky_url = _url_for_path(pids['remote_sky_server_port'], sky_server_root,
            main_dart)

        subprocess.check_call([ADB_PATH, 'shell',
            'am', 'start',
            '-a', 'android.intent.action.VIEW',
            '-d', sky_url])


class StopSky(object):
    """The `stop` sub-command: kill the local server and stop the app on the device."""

    def add_subparser(self, subparsers):
        """Register the `stop` sub-command on `subparsers`."""
        stop_parser = subparsers.add_parser('stop',
            help='kill all running %s processes' % APK_NAME)
        stop_parser.set_defaults(func=self.run)

    def _kill_if_exists(self, pids, key, name):
        """Remove `key` from `pids` and send SIGTERM to the pid stored there.

        Args:
            pids: mapping supporting `pop(key, default)`.
            key: entry holding the process id.
            name: human-readable process name, used only in log messages.

        A missing, non-numeric or non-positive pid is logged and ignored.
        """
        raw_pid = pids.pop(key, None)
        if not raw_pid:
            logging.info('No pid for %s, nothing to do.', name)
            return
        try:
            # bool is an int subclass; True would otherwise become pid 1.
            pid = -1 if isinstance(raw_pid, bool) else int(raw_pid)
        except (TypeError, ValueError):
            pid = -1
        if pid <= 0:
            # The value comes from a file on disk.  A negative pid makes
            # os.kill signal whole process groups, so never pass one through.
            logging.warning('Ignoring invalid pid %r for %s.', raw_pid, name)
            return
        logging.info('Killing %s (%d).', name, pid)
        try:
            os.kill(pid, signal.SIGTERM)
        except OSError:
            logging.info('%s (%d) already gone.', name, pid)

    def run(self, args, pids):
        """Kill the recorded server, remove the adb reverse mapping, stop the app.

        The adb commands use `subprocess.call`, so their failures are ignored.
        All entries in `pids` are cleared afterwards.
        """
        self._kill_if_exists(pids, 'sky_server_pid', 'sky_server')

        if 'remote_sky_server_port' in pids:
            port_string = 'tcp:%s' % pids['remote_sky_server_port']
            subprocess.call([ADB_PATH, 'reverse', '--remove', port_string])

        subprocess.call([
            ADB_PATH, 'shell', 'am', 'force-stop', ANDROID_PACKAGE])

        pids.clear()


class StartTracing(object):
    """The `start_tracing` sub-command: broadcast TRACING_START via adb."""

    def add_subparser(self, subparsers):
        """Register the `start_tracing` sub-command on `subparsers`."""
        parser = subparsers.add_parser(
            'start_tracing', help='start tracing a running sky instance')
        parser.set_defaults(func=self.run)

    def run(self, args, pids):
        """Send the tracing-start broadcast; `args` and `pids` are unused."""
        broadcast_command = [
            ADB_PATH, 'shell', 'am', 'broadcast',
            '-a', 'org.domokit.sky.demo.TRACING_START',
        ]
        subprocess.check_output(broadcast_command)


# Patterns searched for in `adb logcat -d` output while stopping a trace.
# The named group `path` captures the trace file's path on the device.
TRACE_FILE_REGEXP = re.compile(r'Saving trace to (?P<path>\S+)')
# Presence of this text ends the polling loop.
TRACE_COMPLETE_REGEXP = re.compile('Trace complete')


class StopTracing(object):
    """The `stop_tracing` sub-command: stop tracing and download the trace file."""

    def add_subparser(self, subparsers):
        """Register the `stop_tracing` sub-command on `subparsers`."""
        stop_tracing_parser = subparsers.add_parser('stop_tracing',
            help=('stop tracing a running sky instance'))
        stop_tracing_parser.set_defaults(func=self.run)

    def run(self, args, pids):
        """Broadcast TRACING_STOP, wait for completion, then pull the trace.

        `args` and `pids` are unused.

        Returns:
            2 if the log reported completion without ever naming a trace
            file; None after the trace has been pulled and removed from
            the device.

        Raises:
            subprocess.CalledProcessError: if an adb command fails.
        """
        # Clear the log first so only output from this stop request is scanned.
        subprocess.check_output([ADB_PATH, 'logcat', '-c'])
        subprocess.check_output([ADB_PATH, 'shell',
            'am', 'broadcast',
            '-a', 'org.domokit.sky.demo.TRACING_STOP'])
        device_path = None
        is_complete = False
        # NOTE(review): this polls forever if 'Trace complete' is never
        # logged; consider adding a timeout.
        while not is_complete:
            time.sleep(0.2)
            log = subprocess.check_output([ADB_PATH, 'logcat', '-d'])
            if device_path is None:
                result = TRACE_FILE_REGEXP.search(log)
                if result:
                    device_path = result.group('path')
            is_complete = TRACE_COMPLETE_REGEXP.search(log) is not None

        # Check before using the path: basename(None) raises, which used to
        # crash here whenever the path line was missing from the log.
        if not device_path:
            print('Trace complete, but no trace file path was found in the log.')
            return 2

        print('Downloading trace %s ...' % os.path.basename(device_path))
        subprocess.check_output([ADB_PATH, 'pull', device_path])
        subprocess.check_output([ADB_PATH, 'shell', 'rm', device_path])


class SkyShellRunner(object):
    """Command-line entry point: checks tools, parses arguments, runs a sub-command."""

    def _check_for_tool(self, command, missing_message):
        """Run `command`; print `missing_message` and return False on OSError."""
        try:
            subprocess.check_output(command)
        except OSError:
            print(missing_message)
            return False
        return True

    def _check_for_adb(self):
        """Return True if `adb devices` can be executed."""
        return self._check_for_tool(
            [ADB_PATH, 'devices'],
            "'adb' (from the Android SDK) not in $PATH, can't continue.")

    def _check_for_dart(self):
        """Return True if `dart --version` can be executed."""
        return self._check_for_tool(
            [DART_PATH, '--version'],
            "'dart' (from the Dart SDK) not in $PATH, can't continue.")

    def main(self):
        """Run the sub-command named on the command line and exit with its result.

        Exits with status 2 when adb or dart cannot be executed.  The pid
        file is read before the command runs and written back afterwards.
        """
        logging.basicConfig(level=logging.WARNING)
        if not self._check_for_adb() or not self._check_for_dart():
            sys.exit(2)

        parser = argparse.ArgumentParser(description='Sky Demo Runner')
        subparsers = parser.add_subparsers(help='sub-command help')

        for command in [StartSky(), StopSky(), StartTracing(), StopTracing()]:
            command.add_subparser(subparsers)

        args = parser.parse_args()
        pids = Pids.read_from(PID_FILE_PATH, PID_FILE_KEYS)
        try:
            exit_code = args.func(args, pids)
        finally:
            # Persist state even when the command raises part-way through
            # (e.g. an adb call fails after the server was started), so a
            # later `stop` can still find and kill that server.
            pids.write_to(PID_FILE_PATH)
        sys.exit(exit_code)


# Run the tool only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    SkyShellRunner().main()
