#!/usr/bin/env python
# -*- coding: utf-8 -*-
# $Id: Patch.py 10932 2018-03-26 07:56:30Z Kevin $
#
# Copyright (c) 2015 Nuwa Information Co., Ltd, and individual contributors.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#   1. Redistributions of source code must retain the above copyright notice,
#      this list of conditions and the following disclaimer.
#
#   2. Redistributions in binary form must reproduce the above copyright
#      notice, this list of conditions and the following disclaimer in the
#      documentation and/or other materials provided with the distribution.
#
#   3. Neither the name of Nuwa Information nor the names of its contributors
#      may be used to endorse or promote products derived from this software
#      without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# $Author: Kevin $ (last)
# $Date: 2018-03-26 15:56:30 +0800 (一, 26  3 2018) $
# $Revision: 10932 $

from __future__ import print_function

import os
import shutil
import hashlib
import zlib
import base64
import re

from lxml import etree, objectify
from logilab.astng import builder as buildmod

from Iuppiter.Function import enhance

from Iuppiter.Util import colored
from Iuppiter.Util import cd
from Iuppiter.Util import classForName

# Ironically, we use logilab.astng as the tool to patch libraries, but astng
# itself has to be patched first.
#============================================================================
from logilab.astng.node_classes import Arguments, _format_args

def __formatArgs(f, self):
    """
    Replacement body for astng's ``Arguments.format_args``.

    Positional arguments (with their defaults) are rendered by
    ``_format_args``. They are followed by ``*vararg`` and ``**kwarg`` when
    present. Empty pieces are dropped, so a signature without positional
    arguments does not start with a stray ', '.

    @param f The wrapped original method; presumably passed in by ``enhance``
             (unused here).
    @param self The astng Arguments node being formatted.
    @return Comma-separated argument list as source text.
    """
    pieces = [_format_args(self.args, self.defaults)]
    for star, name in (('*', self.vararg), ('**', self.kwarg)):
        if name:
            pieces.append('%s%s' % (star, name))
    return ', '.join(filter(None, pieces))
# Install the patched argument formatter on astng's Arguments node.
# enhance() comes from Iuppiter.Function; presumably it calls __formatArgs with
# the original method as the first argument `f` (confirm against enhance()).
Arguments.format_args = enhance(Arguments.format_args, __formatArgs)

from logilab.astng.as_string import AsStringVisitor

def visitTuple(f, self, node):
    """
    Render an astng Tuple node as source text.

    A trailing comma is always emitted so that a one-element tuple stays a
    tuple (``(a,)`` rather than the parenthesised expression ``(a)``). An
    empty tuple is rendered as ``()``, because ``(,)`` is a syntax error and
    would break the regenerated module.

    @param f The wrapped original visitor method (unused here).
    @param self The AsStringVisitor instance.
    @param node The Tuple node; its ``elts`` are rendered with ``accept``.
    @return Source text of the tuple.
    """
    if not node.elts:
        return '()'
    return '(%s,)' % ', '.join([child.accept(self) for child in node.elts])
# Install the patched tuple renderer on astng's AsStringVisitor; presumably
# enhance() passes the original visit_tuple as visitTuple's `f` argument.
AsStringVisitor.visit_tuple = enhance(AsStringVisitor.visit_tuple, visitTuple)

def visitUnaryop(f, self, node):
    """
    Render an astng UnaryOp node as source text.

    The operand is always wrapped in parentheses so that operator precedence
    is preserved in the regenerated source. The keyword operator 'not' gets a
    trailing space so it stays a separate token.

    @param f The wrapped original visitor method (unused here).
    @param self The AsStringVisitor instance.
    @param node The UnaryOp node (``op`` and ``operand``).
    @return Source text such as ``- (x)`` or ``not  (x)``.
    """
    prefix = {'not': 'not '}.get(node.op, node.op)
    operand = node.operand.accept(self)
    return '%s (%s)' % (prefix, operand)
# Install the patched unary-operator renderer on astng's AsStringVisitor;
# presumably enhance() passes the original visit_unaryop as visitUnaryop's `f`.
AsStringVisitor.visit_unaryop = enhance(AsStringVisitor.visit_unaryop, 
                                        visitUnaryop)
#============================================================================

# http://www.luismartingil.com/2012/08/14/etree-xml-to-python-dictlists/
def formatXML(parent):
    """
    Recursively convert an XML element into nested dicts and lists.

    - Each attribute ``k`` is stored under the key ``'k$'``.
    - The element's leading text (``parent.text``), when non-empty, is stored
      under ``'__content__'``.
    - If the tag contains the word 'List', the children are stored in order
      as a list under ``'__list__'``.
    - Otherwise each child is stored under its own tag name; a later child
      with the same tag overwrites an earlier one.

    Non-element nodes such as comments and processing instructions are
    skipped. Their ``tag`` is a factory callable rather than a tag name, and
    using it as a dict key or in ``'List' in tag`` would be wrong or raise
    TypeError.

    @param parent Parent element.
    @return Dict representation of ``parent``.
    """
    ret = dict(('%s$' % k, v) for k, v in parent.items())

    if parent.text:
        ret['__content__'] = parent.text

    children = [element for element in parent if not callable(element.tag)]
    if 'List' in parent.tag:
        ret['__list__'] = [formatXML(element) for element in children]
    else:
        for element in children:
            ret[element.tag] = formatXML(element)
    return ret
    
def advice(pd, ad):
    """
    Apply advice from a parsed XML patch configuration to a parsed ASTNG node.

    ``pd`` is one element as converted by formatXML:
    - Keys ending with '$' are XML attributes and become control options:
      'type' (default 'replace') and 'return' (default '__return').
    - '__content__' holds Python source to inject.
    - Every other key names a child of ``ad`` (``ad[key]``), which is advised
      recursively with ``pd[key]``.

    Advice types (compared case-insensitively):
        replace -- replace ``ad`` in ``ad.parent.body`` with the module built
                   from the content.
        body    -- replace ``ad.body`` with the content's body.
        after   -- for a function/method, rename it to ``__<name>`` and append
                   a wrapper ``<name>`` with the same signature to its parent.
                   The wrapper calls the original, stores the result in the
                   variable named by 'return' and then runs the content. A
                   method that is not already a staticmethod is made one so
                   the wrapper can call ``<Class>.__<name>(self, ...)``. For
                   other nodes the content's body is appended to ``ad.body``.
        before  -- insert the content's body at the start of ``ad.body``.

    Empty or whitespace-only content is ignored.

    @param pd Parsed XML configuration tree.
    @param ad Parsed Python ASTNG.
    @raise RuntimeError For an unsupported advice type, or when a child named
                        in ``pd`` does not exist in ``ad``.
    """
    astBuilder = buildmod.ASTNGBuilder()

    ctrl = {
        'type': 'replace',
        'return': '__return',
    }

    # Attribute keys ('name$') are control options; the rest are content or
    # children to recurse into.
    nonAttrs = []
    for k in list(pd.keys()):
        if k.endswith('$'):
            ctrl[k[:-1]] = pd[k]
        else:
            nonAttrs.append(k)

    if '__content__' in nonAttrs:
        nonAttrs.remove('__content__')
        c = pd['__content__']
        if c.strip():
            adviceType = ctrl['type'].lower()
            if adviceType == 'replace':
                idx = [i for i, v in enumerate(ad.parent.body) if v == ad][0]
                ad.parent.body[idx] = astBuilder.string_build(c)
            elif adviceType == 'body':
                ad.body = astBuilder.string_build(c).body
            elif adviceType == 'after':
                if ad.is_function:
                    if ad.is_method() and (
                       '__builtin__.staticmethod' not in ad.decoratornames()):
                        _self = '%s.' % ad.parent.name

                        # Generate staticmethod decorator node.
                        f = astBuilder.string_build((
                            "class %s(object):\n"
                            "    @staticmethod\n"
                            "    def f(self):\n"
                            "        pass\n"
                        ) % ad.parent.name)[ad.parent.name]['f']
                        if ad.decorators:
                            ad.decorators.nodes.append(f.decorators.nodes[0])
                        else:
                            ad.decorators = f.decorators
                    else: # Normal function.
                        _self = ''

                    # Forward the wrapper's parameters by name. Reusing the
                    # signature text (e.g. "a, b=1") as call arguments would
                    # pass the default expression instead of the value the
                    # caller supplied.
                    callArgs = [_format_args(ad.args.args)]
                    if ad.args.vararg:
                        callArgs.append('*%s' % ad.args.vararg)
                    if ad.args.kwarg:
                        callArgs.append('**%s' % ad.args.kwarg)
                    callArgs = ', '.join([a for a in callArgs if a])

                    new = (
                        "def %s(%s): \n"
                        "    %s = %s__%s(%s)\n"
                        "    %s\n"
                    ) % (ad.name, ad.args.as_string(),
                         ctrl['return'], _self, ad.name, callArgs,
                         '\n    '.join(c.splitlines()))

                    ad.name = '__%s' % ad.name
                    ad.parent.body.append(astBuilder.string_build(new))
                else:
                    ad.body.extend(astBuilder.string_build(c).body)
            elif adviceType == 'before':
                for i, v in enumerate(astBuilder.string_build(c).body):
                    ad.body.insert(i, v)
            else:
                raise RuntimeError('Not supported advice type: %s' %
                                   ctrl['type'])

    for k in nonAttrs:
        if k not in ad:
            raise RuntimeError(k)

        advice(pd[k], ad[k])
        
def applyPartialPatch(patchFile, targetFile):
    """
    Apply an XML patch configuration to a target Python source file.

    The target is regenerated from its ASTNG (``as_string()``) after the advice
    in ``patchFile`` has been applied. The shebang line and the PEP 263
    encoding line found in the first two lines are written back at the top.
    A trailer line is appended:

        # $Patch: <md5 of the patch file> $ <base64(zlib(original source))>

    With this trailer, re-applying an unchanged patch is a no-op, and a
    changed patch is always applied to the original, unpatched source. If the
    target does not exist, a stub with only a shebang and an encoding line is
    created first.

    The result is written to a temporary file next to the target and parsed
    back before it replaces the target. An invalid result therefore leaves
    the target untouched. The temporary file is always removed.

    @param patchFile Patch XML file.
    @param targetFile Target source file.
    """
    astBuilder = buildmod.ASTNGBuilder()

    with open(patchFile, 'rb') as f:
        patchConf = f.read()

    patchMd5 = hashlib.md5(patchConf).hexdigest()
    patchDict = formatXML(etree.fromstring(patchConf))

    if not os.path.exists(targetFile):
        with open(targetFile, 'wb') as f:
            f.writelines([
                "#!/usr/bin/env python\n",
                "# -*- coding: utf-8 -*-\n",
                "# CREATED BY PATCH \n",
            ])

    with open(targetFile, 'rb') as f:
        source = f.read()

    targetDir, targetName = os.path.split(targetFile)
    tmpFile = os.path.join(targetDir, '_%s' % targetName)

    patchedTagIdx = source.rfind('# $Patch: ')
    if patchedTagIdx != -1: # Patched file.
        patchedMd5 = source[patchedTagIdx + len('# $Patch: '):
                            source.rfind(' $')]
        if patchMd5 == patchedMd5:
            # Patch configuration not changed, do not need to patch twice.
            return

        # Apply patch to original source, not patched source. Strip the
        # separating space and trailing newline so they do not accumulate in
        # the trailer every time the file is re-patched.
        sourceEncoded = source[source.rfind('$') + 1:].strip()
        source = zlib.decompress(base64.b64decode(sourceEncoded))

        # Write the original source to a temp file because only
        # astBuilder.file_build can be used: string_build appends '\n', which
        # makes compile() report a SyntaxError on '\r\n' sources. The target
        # itself is not touched here so it is not broken if this fails.
        try:
            with open(tmpFile, 'wb') as f:
                f.write(source)
            moduleNode = astBuilder.file_build(tmpFile)
        finally:
            if os.path.exists(tmpFile):
                os.remove(tmpFile)
    else:
        sourceEncoded = base64.b64encode(
            zlib.compress(source, zlib.Z_BEST_COMPRESSION))

        moduleNode = astBuilder.file_build(targetFile)

    advice(patchDict, moduleNode)

    # Recover the shebang (line 1) and PEP 263 encoding declaration (line 1
    # or 2) so they can be re-emitted above the regenerated source.
    shebangRE = re.compile(r"(^#!.*)")
    pep0263RE = re.compile(r"(^#.*coding[\:=].*)")

    # str.split always returns at least one item, so index 0 is safe.
    first2Lines = source.split('\n')[0:2]

    shebangMatch = shebangRE.match(first2Lines[0])
    shebang = shebangMatch.group(1) if shebangMatch else None

    pep0263Match = pep0263RE.match(first2Lines[0])
    if not pep0263Match and len(first2Lines) == 2:
        pep0263Match = pep0263RE.match(first2Lines[1])
    pep0263 = pep0263Match.group(1) if pep0263Match else None

    # Write to a temp file, check that it parses, and only then overwrite the
    # target, so a bad patch never breaks the original file.
    try:
        with open(tmpFile, 'wb') as f:
            if shebang:
                f.write(shebang)
                f.write('\n')

            if pep0263:
                f.write(pep0263)
                f.write('\n')

            f.write(moduleNode.as_string())
            f.write('\n')
            f.write('# $Patch: %s $ %s\n' % (patchMd5, sourceEncoded))

        # Check patched file valid or not (raises on failure).
        astBuilder.file_build(tmpFile)

        # Overwrite original file.
        shutil.copyfile(tmpFile, targetFile)
    finally:
        if os.path.exists(tmpFile):
            os.remove(tmpFile)

def patchFile(baseDir, fn, targetDir, quiet=False):
    """
    Apply a single patch file from ``baseDir`` to ``targetDir``.

    - A file named ``<name>.patch.xml`` is a partial patch. It is applied with
      applyPartialPatch to ``targetDir/<name>``.
    - Any other file is copied verbatim over ``targetDir/<fn>``.

    @param baseDir Base directory of patch file.
    @param fn Patch file name.
    @param targetDir Target directory.
    @param quiet Quiet run or print out the progress.
    """
    suffix = '.patch.xml'
    source = os.path.join(baseDir, fn)

    if not fn.endswith(suffix):
        destination = os.path.join(targetDir, fn)
        if not quiet:
            print('Patch %s' % destination)
        shutil.copyfile(source, destination)
        return

    destination = os.path.join(targetDir, fn[:-len(suffix)])
    if not quiet:
        for text, color in (('Applying', 'yellow'),
                            (source, 'white'),
                            ('to', 'yellow')):
            print(colored(text, color, attrs=['bold']), end=' ')
        print(colored(destination, 'red', attrs=['bold']))

    applyPartialPatch(source, destination)
        
def overwrite(base, d, target, copy=False, 
              ignores=('.svn', '.git'), quiet=False):
    """
    Overwrite/patch files under ``target`` with the patch files in
    ``base/d``, walking ``base/d`` recursively.

    Each patch directory (path relative to ``base``) is matched with
    ``target/<relative>``. When that directory does not exist:

    - If ``copy`` is True, it is created.
    - If ``copy`` is False, it is looked up by importing its dotted name with
      classForName and using the directory of the returned object's
      ``__file__``. If that import raises ImportError, the directory is looked
      up next to the already resolved location of its parent patch directory.
      A directory that still cannot be found is reported and skipped.

    Every non-ignored file in a matched directory is handed to patchFile.

    @param base Base directory to store patch files.
    @param d Directory relative to base.
    @param target Target root directory that patch directories are matched
                  against.
    @param copy True to create target directories that do not exist; False
                to patch only directories that can be found.
    @param ignores Names of directories and files to skip.
    @param quiet Quiet run or print out the progress.
    """
    ignored = set(ignores)

    # Resolved target directory for every patch directory found so far, keyed
    # by its relative path. os.walk is top-down, so a parent is always
    # resolved before its children are visited.
    resolved = {}

    with cd(base):
        for root, dirs, files in os.walk(d):
            # Prune ignored sub-directories so os.walk never descends into
            # them.
            dirs[:] = [x for x in dirs if x not in ignored]

            relative = os.path.normpath(root)
            if any(part in ignored for part in relative.split(os.path.sep)):
                continue

            sdir = os.path.join(target, relative)
            exists = os.path.exists(sdir)

            if not copy:
                if not exists:
                    # Try import existed package to find its path.
                    oriSdir = None
                    try:
                        module = classForName(
                            relative.replace(os.path.sep, '.'))
                        sdir = os.path.dirname(module.__file__)
                        exists = os.path.exists(sdir)
                    except ImportError:
                        # Not importable (e.g. a data directory without
                        # __init__.py): look for it inside the resolved
                        # location of its parent patch directory.
                        parentDir = resolved.get(os.path.dirname(relative))
                        if parentDir:
                            oriSdir = sdir
                            sdir = os.path.join(parentDir,
                                                os.path.basename(relative))
                            exists = os.path.exists(sdir)
                    except Exception:
                        # Best effort: any other failure while importing just
                        # means the directory is reported as not found below.
                        pass

                    if not exists:
                        if not quiet:
                            print(colored("Couldn't find installed", 'red'), end=' ')
                            print(colored(oriSdir if oriSdir else sdir, 
                                          'white', attrs=['bold']), end=' ') 
                            print(colored(" NOT PATCHED", 'red', attrs=['bold']))

                        continue
            elif not exists:
                # Make directory.
                if not quiet:
                    print('mkdir %s' % sdir)
                os.makedirs(sdir)
                exists = os.path.exists(sdir)

            if exists:
                resolved[relative] = sdir

            for f in files:
                if f in ignored:
                    continue
                patchFile(root, f, sdir, quiet=quiet)

def patch(patchDir, virtualenv=None, special=None, strict=True,
          requirements=None, ignores=('.svn', '.git'), quiet=False):
    """
    Patch python libraries that we patched.

    Each top-level entry of ``patchDir`` (except ``ignores``) names a package
    or module to patch. Its target root is DistUtil.getSitePackagesDir(),
    unless pkg_resources finds an installed distribution of that name, in
    which case the distribution's location is used instead (packages
    installed by easy_install may live in their own .egg folder). Then:

    - An entry in ``special`` is handed to its handler as
      ``handler(patchDir, name, sitePackagesDir, targetPath)``.
    - A file is applied with patchFile.
    - A directory is applied recursively with overwrite.

    @param patchDir Base directory to store patch files.
    @param virtualenv Virtualenv path (not used by this function).
    @param special Special hook for packages, a dictionary which key is package
                   name and value is handle function. None means no hooks.
    @param strict True if only existed files will be patched.
    @param requirements Patch predicate requirements, for instance: {
                            'django': [ # patch folder name.
                                'django >= 1.5', # package requirement lines.
                            ]
                        }
                        An entry is skipped when any of its requirements is
                        not satisfied.
    @param ignores Ignore directories, files.
    @param quiet Quiet run or print out the progress.
    """
    import pkg_resources

    # Please note this import must be inside here because it should be
    # imported after virtualenv activated.
    from Iuppiter import DistUtil

    if special is None:
        special = {}

    ignored = set(ignores)

    patchDir = os.path.abspath(patchDir)

    if not quiet:
        print()
        print(colored('*' * 20, 'cyan', attrs=['bold']), end=' ')
        print(colored('Update Patches', 'yellow', 'on_red', attrs=['bold']), end=' ')
        print(colored('*' * 20, 'cyan', attrs=['bold']))
        print()

    with cd(patchDir):
        for d in os.listdir('.'):
            if d in ignored:
                continue

            sitePackagesDir = DistUtil.getSitePackagesDir()
            relative = os.path.normpath(d)

            # Package installed using easy_install sometimes resides in .egg
            # folder.
            try:
                sitePackagesDir = pkg_resources.get_distribution(d).location
            except Exception as e:
                # Not a distribution name (or lookup failed): keep the default
                # site-packages directory.
                if not quiet:
                    print(str(e))

            rsdir = os.path.join(sitePackagesDir, relative)

            if requirements and d in requirements:
                unmet = None
                for req in requirements[d]:
                    reqObj = pkg_resources.Requirement.parse(req)
                    try:
                        pkg_resources.get_distribution(reqObj)
                    except (pkg_resources.DistributionNotFound,
                            pkg_resources.VersionConflict):
                        unmet = req
                        break

                if unmet is not None:
                    if not quiet:
                        print(colored(
                            'Patch [%s] requires %s, not satisfied.' % (d, unmet),
                            'red', attrs=['bold']))
                        print()
                    continue

            if not quiet:
                print(colored('=' * 18, 'white', attrs=['bold']), end=' ')
                print(colored('Patching [%s]' % d, 'yellow', attrs=['bold']), end=' ')
                print(colored('=' * 18, 'white', attrs=['bold']))

            if d in special:
                special[d](patchDir, d, sitePackagesDir, rsdir)
            elif not os.path.isdir(d):
                patchFile(patchDir, d, sitePackagesDir, quiet=quiet)
            else:
                overwrite(patchDir, d, target=sitePackagesDir,
                          copy=not strict,
                          ignores=ignores, quiet=quiet)

            if not quiet:
                print()
