#!/usr/bin/python

import sys, os, glob, os.path, shutil, re

# Rebind the name ``glob`` from the module to its ``glob()`` function so the
# rest of this script can call ``glob(pattern)`` directly.  After this line
# the module object is no longer reachable through the name ``glob``.
glob = glob.glob

def cmd(c):
    """Run the shell command ``c`` through ``os.system``.

    Raises ``Exception`` (with the command text in the message) when the
    status returned by ``os.system`` is anything other than zero.
    """
    status = os.system(c)
    if status == 0:
        return
    raise Exception('command execution failed: ' + c)

# Version string embedded into the generated module sources: taken from the
# first command-line argument when one is given, 'kvm-devel' otherwise.
version = sys.argv[1] if len(sys.argv) >= 2 else 'kvm-devel'

# Directory prefix used in every glob pattern that reads from the Linux tree.
linux = 'linux-2.6'

# Compiled patterns, keyed by the regular-expression source text.
_re_cache = {}

def re_cache(regexp):
    """Return the compiled pattern for ``regexp``, compiling it on first use.

    Later calls with the same ``regexp`` return the object stored in
    ``_re_cache`` instead of compiling again.
    """
    try:
        return _re_cache[regexp]
    except KeyError:
        compiled = re.compile(regexp)
        _re_cache[regexp] = compiled
        return compiled

def __hack(data):
    """Return ``data`` with the compatibility rewrites applied line by line.

    ``data`` is the complete text of one source file.  It is split into
    lines; every line is tested against a fixed series of regular
    expressions and, depending on what matches, is rewritten, blanked, or
    has extra lines emitted before or after it.  The returned string is the
    resulting list of lines, each terminated by a newline.

    Reads the module-level ``version`` string, which is embedded in the
    generated ``printk`` and ``MODULE_INFO`` lines, and uses ``re_cache`` for
    all pattern compilation.
    """
    # Identifiers that get a ``kvm_`` prefix wherever they appear as a word.
    compat_apis = str.split(
        'INIT_WORK desc_struct ldttss_desc64 desc_ptr '
        'hrtimer_add_expires_ns hrtimer_get_expires '
        'hrtimer_get_expires_ns hrtimer_start_expires '
        'hrtimer_expires_remaining smp_send_reschedule '
        'on_each_cpu relay_open request_irq free_irq')
    # "Currently inside function X" flags.  All three must start out False:
    # kvm_arch_init in particular is tested on any line mentioning tsc_khz,
    # which may come before the first line that would otherwise assign it.
    anon_inodes = anon_inodes_exit = kvm_arch_init = False
    result = []
    w = result.append

    def sub(regexp, repl, text):
        return re_cache(regexp).sub(repl, text)

    def match(regexp):
        # Deliberately searches the *current* value of ``line``, so tests
        # further down see rewrites made earlier in the same iteration.
        return re_cache(regexp).search(line)

    for line in data.splitlines():
        # Whitespace-separated fields of the line as read, before rewrites.
        f = line.split()
        if match(r'^int kvm_init\('): anon_inodes = True
        if match(r'return 0;') and anon_inodes:
            # Emitted ahead of the first 'return 0;' after 'int kvm_init('.
            w('\tr = kvm_init_anon_inodes();')
            w('\tif (r) {')
            w('\t\t__free_page(bad_page);')
            w('\t\tgoto out;')
            w('\t}')
            w('\tpreempt_notifier_sys_init();')
            w('\tprintk("loaded kvm module (%s)\\n");\n' % (version,))
            anon_inodes = False
        if match(r'^void kvm_exit'): anon_inodes_exit = True
        if match(r'\}') and anon_inodes_exit:
            # Emitted ahead of the first '}' after 'void kvm_exit'.
            w('\tkvm_exit_anon_inodes();')
            w('\tpreempt_notifier_sys_exit();')
            anon_inodes_exit = False
        if match(r'^int kvm_arch_init'): kvm_arch_init = True
        if match(r'\btsc_khz\b') and kvm_arch_init:
            line = sub(r'\btsc_khz\b', 'kvm_tsc_khz', line)
        # A '}' in column 0 ends the kvm_arch_init region.
        if match(r'^}'): kvm_arch_init = False
        if match(r'MODULE_AUTHOR'):
            w('MODULE_INFO(version, "%s");' % (version,))
        line = sub(r'match->dev->msi_enabled',
                   'kvm_pcidev_msi_enabled(match->dev)', line)
        if match(r'atomic_inc\(&kvm->mm->mm_count\);'):
            line = 'mmget(&kvm->mm->mm_count);'
        if match(r'^\t\.fault = '):
            # f[2] is the handler name followed by a comma.
            fcn = sub(r',', '', f[2])
            line = '\t.VMA_OPS_FAULT(fault) = VMA_OPS_FAULT_FUNC(' + fcn + '),'
        if match(r'^static int (.*_stat_get|lost_records_get)'):
            # 11 == len('static int '): prefix the function name with '__'.
            line = line[0:11] + '__' + line[11:]
        if match(r'DEFINE_SIMPLE_ATTRIBUTE.*(_stat_get|lost_records_get)'):
            name = sub(r',', '', f[1])
            w('MAKE_SIMPLE_ATTRIBUTE_GETTER(' + name + ')')
        line = sub(r'linux/mm_types\.h', 'linux/mm.h', line)
        line = sub(r'\b__user\b', ' ', line)
        if match(r'^\t\.name = "kvm"'):
            line = '\tset_kset_name("kvm"),'
        if match(r'#include <linux/compiler.h>'):
            line = ''
        if match(r'#include <linux/clocksource.h>'):
            line = ''
        if match(r'#include <linux\/types.h>'):
            line = '#include <asm/types.h>'
        line = sub(r'\bhrtimer_init\b', 'hrtimer_init_p', line)
        line = sub(r'\bhrtimer_start\b', 'hrtimer_start_p', line)
        line = sub(r'\bhrtimer_cancel\b', 'hrtimer_cancel_p', line)
        if match(r'case KVM_CAP_SYNC_MMU'):
            line = '#ifdef CONFIG_MMU_NOTIFIER\n' + line + '\n#endif'
        for ident in compat_apis:
            line = sub(r'\b' + ident + r'\b', 'kvm_' + ident, line)
        if match(r'kvm_.*_fops\.owner = module;'):
            line = 'IF_ANON_INODES_DOES_REFCOUNTS(' + line + ')'
        if not match(r'#include'):
            # NOTE(review): this pattern requires a newline right after
            # 'lapic', which a line produced by splitlines() does not have,
            # so the rename looks like it never fires.  Possibly r'\blapic\b'
            # was intended -- confirm before changing, since that would alter
            # the generated sources.
            line = sub(r'\blapic\n', 'l_apic', line)
        w(line)
        # The remaining rules emit a line *after* the one just written.
        if match(r'\tkvm_init_debug'):
            w('\thrtimer_kallsyms_resolve();')
        if match(r'apic->timer.dev.function ='):
            w('\thrtimer_data_pointer(&apic->timer.dev);')
        if match(r'pt->timer.function ='):
            w('\thrtimer_data_pointer(&pt->timer);')
    return ''.join([out + '\n' for out in result])

def _hack(fname, arch):
    """Rewrite the file ``fname`` in place with the output of ``__hack``.

    The whole file is read before it is reopened for writing.  ``arch`` is
    accepted but not used by this function.  Files are opened with
    ``open()`` inside ``with`` blocks so the handles are closed
    deterministically.
    """
    with open(fname) as src:
        data = src.read()
    data = __hack(data)
    with open(fname, 'w') as dst:
        dst.write(data)

def unifdef(fname):
    """Prepend the contents of ``unifdef.h`` to the file ``fname`` in place.

    ``unifdef.h`` is looked up relative to the current working directory.
    Both inputs are read completely before ``fname`` is reopened for
    writing, and every handle is closed as soon as it is no longer needed.
    """
    with open('unifdef.h') as header:
        prefix = header.read()
    with open(fname) as target:
        body = target.read()
    with open(fname, 'w') as target:
        target.write(prefix + body)

def hack(T, arch, file):
    """Apply ``_hack`` to ``file``, given relative to the directory ``T``."""
    target = '/'.join((T, file))
    _hack(target, arch)

# Per-architecture list of staged source files that are passed through hack().
hack_files = {
    'x86': [
        'kvm_main.c', 'mmu.c', 'vmx.c', 'svm.c', 'x86.c', 'irq.h', 'lapic.c',
        'i8254.c', 'kvm_trace.c', 'timer.c',
    ],
    'ia64': ['kvm_main.c', 'kvm_fw.c', 'kvm_lib.c', 'kvm-ia64.c'],
}

def mkdir(dir):
    """Create directory ``dir`` together with any missing parents.

    Does nothing when ``dir`` already exists.  An empty path is also a
    no-op: that is what ``os.path.dirname`` yields for a bare file name, and
    passing it on to ``os.makedirs`` would raise.
    """
    if dir and not os.path.exists(dir):
        os.makedirs(dir)

def cp(src, dst):
    """Copy the contents of file ``src`` to ``dst``.

    The directory part of ``dst`` is created first via ``mkdir``.  The copy
    itself is delegated to ``shutil.copyfile``, which closes both files when
    done and refuses to copy a file onto itself (the previous open-coded
    version truncated the file before reading it in that case).
    """
    mkdir(os.path.dirname(dst))
    shutil.copyfile(src, dst)

def _same_contents(path_a, path_b):
    """Return True only if both files are readable and byte-for-byte equal."""
    try:
        with open(path_a, 'rb') as fa:
            data_a = fa.read()
        with open(path_b, 'rb') as fb:
            data_b = fb.read()
    except (IOError, OSError):
        # A missing or unreadable file counts as "different".
        return False
    return data_a == data_b

def copy_if_changed(src, dst):
    """Mirror the tree under ``src`` into ``dst``, copying only changed files.

    Every directory found by ``os.walk(src)`` is created under ``dst``.  A
    file is copied with ``cp`` when its counterpart under ``dst`` is missing,
    unreadable, or has different contents; identical files are left
    untouched.  Only I/O errors are treated as "different" -- other
    exceptions (for example KeyboardInterrupt) propagate to the caller.
    """
    for dir, subdirs, files in os.walk(src):
        # dir[len(src)+1:] is the walked directory relative to ``src``.
        ndir = dst + '/' + dir[len(src)+1:]
        mkdir(ndir)
        for fname in files:
            old = ndir + '/' + fname
            new = dir + '/' + fname
            if not _same_contents(old, new):
                cp(new, old)

def rmtree(path):
    """Remove the directory tree at ``path``; do nothing if it is absent."""
    if not os.path.exists(path):
        return
    shutil.rmtree(path)

def header_sync(arch):
    """Stage, patch and install the KVM headers for ``arch``.

    Works in a scratch directory named 'header', which is removed both
    before and after the run:

    1. copy ``include/linux/kvm*.h`` from the Linux tree and prepend
       ``unifdef.h`` to each copy;
    2. do the same for the kvm*/vmx*/svm*/virtext* headers found under
       ``arch/<arch>/include/asm``, placing them in ``include/asm-<arch>``;
    3. run ``hack`` over the two staged ``kvm.h`` files;
    4. copy whatever changed into the current directory.
    """
    staging = 'header'
    rmtree(staging)
    subst = { 'arch': arch, 'linux': linux }

    for header in glob('%(linux)s/include/linux/kvm*.h' % subst):
        out = '%(T)s/include/linux/%(name)s' % {
            'T': staging, 'name': os.path.basename(header) }
        cp(header, out)
        unifdef(out)

    arch_patterns = [
        '%(linux)s/arch/%(arch)s/include/asm/./kvm*.h',
        '%(linux)s/arch/%(arch)s/include/asm/./vmx*.h',
        '%(linux)s/arch/%(arch)s/include/asm/./svm*.h',
        '%(linux)s/arch/%(arch)s/include/asm/./virtext*.h',
    ]
    # Collect every match up front, then copy.
    arch_headers = []
    for pattern in arch_patterns:
        arch_headers.extend(glob(pattern % subst))
    for header in arch_headers:
        out = '%(T)s/include/asm-%(arch)s/%(name)s' % {
            'T': staging, 'name': os.path.basename(header), 'arch': arch }
        cp(header, out)
        unifdef(out)

    hack(staging, 'x86', 'include/linux/kvm.h')
    hack(staging, arch, 'include/asm-%(arch)s/kvm.h' % { 'arch': arch })
    copy_if_changed(staging, '.')
    rmtree(staging)

def source_sync(arch):
    """Stage, patch and install the KVM sources for ``arch``.

    Works in a scratch directory named 'source', which is removed both
    before and after the run:

    1. copy every ``*.c``, ``*.S`` and ``*.h`` file from ``arch/<arch>/kvm``
       and ``virt/kvm`` in the Linux tree, skipping names ending in
       '.mod.c', flat into the scratch directory;
    2. prepend ``unifdef.h`` to each staged ``.c`` file;
    3. run ``hack`` over the files listed in ``hack_files[arch]``;
    4. copy whatever changed into the directory named after ``arch``.
    """
    staging = 'source'
    rmtree(staging)
    subst = { 'linux': linux, 'arch': arch }

    sources = []
    for pattern in ['%(linux)s/arch/%(arch)s/kvm/*.[cSh]',
                    '%(linux)s/virt/kvm/*.[cSh]']:
        for path in glob(pattern % subst):
            if path.endswith('.mod.c'):
                continue
            sources.append(path)

    for path in sources:
        out = '%(T)s/%(name)s' % {
            'T': staging, 'name': os.path.basename(path) }
        cp(path, out)

    for staged in glob(staging + '/*.c'):
        unifdef(staged)

    for name in hack_files[arch]:
        hack(staging, arch, name)

    copy_if_changed(staging, arch)
    rmtree(staging)

# Script entry: sync the headers, then the sources, for each architecture.
for arch in ('x86', 'ia64'):
    header_sync(arch)
    source_sync(arch)
