#!/usr/bin/env python
# -*- coding: utf-8 -*-
# $Id: PatchWAT.py 14716 2024-07-18 11:19:22Z Tim $
#
# Copyright (c) 2024 Nuwa Information Co., Ltd, All Rights Reserved.
#
# Licensed under the Proprietary License,
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at our web site.
#
# See the License for the specific language governing permissions and
# limitations under the License.
#
# $Author: Tim $
# $Date: 2024-07-18 19:19:22 +0800 (週四, 18 七月 2024) $
# $Revision: 14716 $

import re
import os
import logging
from typing import Tuple, Optional

# Module-level logger; patchWatCallbacks reports patching failures and the
# successful write through it.
logger = logging.getLogger(__name__)

# WAT instructions that addResponseReadCallbackCode splices into the
# $__wasi_fd_read wrapper, right after the wrapper's own argument setup and
# before its call to the imported fd_read.
# - `call $responseReadCallback` consumes the values the argument setup left on
#   the stack (presumably the wrapper's four i32 params — TODO confirm against
#   the input WAT).
# - Its i32 result is dropped.
# - Locals 0-3 are pushed again so that the original fd_read call that follows
#   still receives its arguments.
FILE_READ_CALLBACK_CODE = '''
    ;; Call responseReadCallback
    call $responseReadCallback
    ;; Drop the results from f as they are not used
    drop
    ;; Call the original WASI fd_read function
    local.get 0
    local.get 1
    local.get 2
    local.get 3
'''

# Replacement body that addResponseOpenCallbackCode substitutes for the whole
# body of the $__wasi_path_open wrapper (param i32 i32 i32 i32 i64 i64 i32 i32).
# It:
#   1. declares a local $result;
#   2. calls the imported path_open with the wrapper's params, inserting the
#      length of param 2 (computed with $strlen, which is assumed to already
#      exist in the module — TODO confirm) after it;
#   3. keeps only the low 16 bits of the returned value (AND 65535) in $result;
#   4. calls $responseOpenCallback with the eight wrapper params and drops its
#      result;
#   5. returns $result.
# The callback therefore runs after the real path_open call.
FILE_OPEN_CALLBACK_CODE = '''
   (local $result i32) ;; Define a local variable to store the result
    ;; Call the original WASI path_open function
    local.get 0
    local.get 1
    local.get 2
    local.get 2
    call $strlen
    local.get 3
    local.get 4
    local.get 5
    local.get 6
    local.get 7
    call $__imported_wasi_snapshot_preview1_path_open
    ;; Store the result of path_open in a local variable
    local.tee $result
    i32.const 65535
    i32.and
    local.set $result ;; Store the result of the and operation in a local variable

    ;; Call responseOpenCallback with its required parameters
    local.get 0
    local.get 1
    local.get 2
    local.get 3
    local.get 4
    local.get 5
    local.get 6
    local.get 7
    call $responseOpenCallback
    drop ;; Drop the result of responseOpenCallback as it's not needed

    ;; Return the result of the original WASI path_open call
    local.get $result
'''

# Replacement body that addFileWriteCallbackCode substitutes for the whole body
# of the $__wasi_fd_write wrapper (param i32 i32 i32 i32). It:
#   1. declares a local $result;
#   2. calls the imported fd_write with the four params;
#   3. keeps only the low 16 bits of the returned value (AND 65535) in $result;
#   4. calls $fileWriteCallback with the same four params and drops its result;
#   5. returns $result.
# The callback therefore runs after the real fd_write call.
FILE_WRITE_CALLBACK_CODE = '''
   (local $result i32) ;; Define a local variable to store the result
    ;; Call the original WASI fd_write function
    local.get 0
    local.get 1
    local.get 2
    local.get 3
    call $__imported_wasi_snapshot_preview1_fd_write
    ;; Store the result of fd_write in a local variable
    local.tee $result
    i32.const 65535
    i32.and
    local.set $result ;; Store the result of the and operation in a local variable

    ;; Call fileWriteCallback with its required parameters
    local.get 0
    local.get 1
    local.get 2
    local.get 3
    call $fileWriteCallback
    drop ;; Drop the result of fileWriteCallback as it's not needed

    ;; Return the result of the original WASI fd_write call
    local.get $result
'''

def addResponseReadCallbackCode(content: str, callbackCode: Optional[str] = None) -> str:
    """Insert callback instructions into the ``$__wasi_fd_read`` wrapper.

    The wrapper is expected to have the shape::

        (func $__wasi_fd_read (type N) (param i32 i32 i32 i32) (result i32)
          <argument setup>
          call $__imported_wasi_snapshot_preview1_fd_read
          i32.const N
          i32.and)

    ``callbackCode`` is inserted directly after ``<argument setup>``, i.e.
    before the imported ``fd_read`` call. The rest of the function is kept
    as is.

    :param content: WAT source text.
    :param callbackCode: WAT instructions to insert; defaults to
        ``FILE_READ_CALLBACK_CODE``. Surrounding whitespace is stripped.
    :return: the patched text, or ``content`` unchanged if no wrapper of
        that shape is found.
    """
    if callbackCode is None:
        callbackCode = FILE_READ_CALLBACK_CODE

    # Group 1: the whole wrapper function; group 2: its argument setup.
    funcPattern = re.compile(
        r'(\(func \$__wasi_fd_read \(type \d+\) \(param i32 i32 i32 i32\) \(result i32\)\s*'
        r'(.*?)\s*'
        r'call \$__imported_wasi_snapshot_preview1_fd_read\s*'
        r'i32\.const \d+\s*'
        r'i32\.and\s*\))',
        re.DOTALL
    )

    def addCode(match):
        func = match.group(1)
        base = match.start(1)
        # Splice around group 2 by position. A str.replace() on the body text
        # would rewrite every occurrence of it, and with an empty body would
        # insert the snippet between every character of the function.
        head = func[:match.start(2) - base]
        tail = func[match.end(2) - base:]
        originalBody = match.group(2)
        newBody = f'{originalBody}\n    {callbackCode}'.strip()
        if not originalBody:
            # An empty body puts the splice point directly before `call`;
            # keep the inserted instructions on their own line.
            newBody += '\n    '
        return head + newBody + tail

    return funcPattern.sub(addCode, content)

def addResponseOpenCallbackCode(content: str, callbackCode: Optional[str] = None) -> str:
    """Replace the body of the ``$__wasi_path_open`` wrapper.

    The wrapper is expected to have the shape::

        (func $__wasi_path_open (type N) (param i32 i32 i32 i32 i64 i64 i32 i32) (result i32)
          <argument setup>
          call $__imported_wasi_snapshot_preview1_path_open
          i32.const N
          i32.and)

    Everything after the header and the whitespace that follows it is
    replaced by ``callbackCode`` (stripped), followed by ``"\\n)"`` to close
    the function again.

    :param content: WAT source text.
    :param callbackCode: replacement body; defaults to
        ``FILE_OPEN_CALLBACK_CODE``.
    :return: the patched text, or ``content`` unchanged if no wrapper of
        that shape is found.
    """
    if callbackCode is None:
        callbackCode = FILE_OPEN_CALLBACK_CODE
    replacement = callbackCode.strip()

    # Group 1: the whole wrapper function; group 2: its original body.
    funcPattern = re.compile(
        r'(\(func \$__wasi_path_open \(type \d+\) \(param i32 i32 i32 i32 i64 i64 i32 i32\) \(result i32\)\s*'
        r'(.*?)\s*'
        r'call \$__imported_wasi_snapshot_preview1_path_open\s*'
        r'i32\.const \d+\s*'
        r'i32\.and\s*\))',
        re.DOTALL
    )

    def addCode(match):
        # Keep the header (and the whitespace after it) by slicing up to where
        # group 2 starts. str.split(body) would raise ValueError on an empty
        # body and could cut at an earlier textual occurrence.
        header = match.group(1)[:match.start(2) - match.start(1)]
        return header + replacement + '\n)'

    return funcPattern.sub(addCode, content)

def addFileWriteCallbackCode(content: str, callbackCode: Optional[str] = None) -> str:
    """Replace the body of the ``$__wasi_fd_write`` wrapper.

    The wrapper is expected to have the shape::

        (func $__wasi_fd_write (type N) (param i32 i32 i32 i32) (result i32)
          <argument setup>
          call $__imported_wasi_snapshot_preview1_fd_write
          i32.const N
          i32.and)

    Everything after the header and the whitespace that follows it is
    replaced by ``callbackCode`` (stripped), followed by ``"\\n)"`` to close
    the function again.

    :param content: WAT source text.
    :param callbackCode: replacement body; defaults to
        ``FILE_WRITE_CALLBACK_CODE``.
    :return: the patched text, or ``content`` unchanged if no wrapper of
        that shape is found.
    """
    if callbackCode is None:
        callbackCode = FILE_WRITE_CALLBACK_CODE
    replacement = callbackCode.strip()

    # Group 1: the whole wrapper function; group 2: its original body.
    funcPattern = re.compile(
        r'(\(func \$__wasi_fd_write \(type \d+\) \(param i32 i32 i32 i32\) \(result i32\)\s*'
        r'(.*?)\s*'
        r'call \$__imported_wasi_snapshot_preview1_fd_write\s*'
        r'i32\.const \d+\s*'
        r'i32\.and\s*\))',
        re.DOTALL
    )

    def addCode(match):
        # Keep the header (and the whitespace after it) by slicing up to where
        # group 2 starts. str.split(body) would raise ValueError on an empty
        # body and could cut at an earlier textual occurrence.
        header = match.group(1)[:match.start(2) - match.start(1)]
        return header + replacement + '\n)'

    return funcPattern.sub(addCode, content)

def insertFunctions(originalWat, functionsToAdd):
    """Insert lines into the WAT content right after the module opener.

    :param originalWat: WAT source text.
    :param functionsToAdd: lines (already indented) to insert, in order; any
        iterable of strings.
    :return: the WAT text with ``functionsToAdd`` placed immediately after the
        first line that, once stripped, is ``(module`` or ``(module $name``.
    :raises ValueError: if no such module-opener line exists.
    """
    # Only accept an opener standing alone on its line. If the opener shares
    # a line with other fields, inserting after that line could place the new
    # lines in the wrong spot or outside the module.
    moduleOpener = re.compile(r'\(module(?:\s+\$[^\s()]+)?')
    lines = originalWat.split('\n')
    insertIndex = next(
        (i + 1 for i, line in enumerate(lines) if moduleOpener.fullmatch(line.strip())),
        None,
    )
    if insertIndex is None:
        # Without an opener there is no safe insertion point; fail loudly
        # rather than with an unbound-variable error.
        raise ValueError("No '(module' line found in WAT content; cannot insert functions.")
    return '\n'.join(lines[:insertIndex] + list(functionsToAdd) + lines[insertIndex:])

def patchWatCallbacks(watPath: str, newWatPath: Optional[str] = None) -> Tuple[bool, Optional[str]]:
    """
        Add callback imports to a WAT file and hook them into the WASI wrappers.

        Reads ``watPath`` and inserts import declarations for
        ``responseReadCallback``, ``responseOpenCallback`` and
        ``fileWriteCallback`` right after the module opener. It then rewrites
        ``$__wasi_fd_read``, ``$__wasi_path_open`` and ``$__wasi_fd_write`` so
        that they call those imports. The result is written to ``newWatPath``.

        A wrapper that is not found in the expected form is left unpatched and
        a warning is logged; the file is still written.

        :param watPath: path of the source ``.wat`` file.
        :param newWatPath: destination path; defaults to ``generated.wat`` in
            the same directory as ``watPath``.
        :return: ``(True, newWatPath)`` on success, or ``(False, None)`` if
            patching the in-memory content raised. Errors while reading or
            writing the files are not caught and propagate to the caller.
    """

    if newWatPath is None:
        # default: same directory as the input file
        newWatPath = os.path.join(
            os.path.dirname(watPath), "generated.wat"
        )

    with open(watPath, 'r', encoding="utf8") as file:
        originalWat = file.read()

    # Callbacks supplied by the host, declared as imports from module "".
    newFunctions = [
        '  (func $responseReadCallback (import "" "responseReadCallback") (param i32 i32 i32 i32) (result i32))',
        '  (func $responseOpenCallback (import "" "responseOpenCallback") (param i32 i32 i32 i32 i64 i64 i32 i32) (result i32))',
        '  (func $fileWriteCallback (import "" "fileWriteCallback") (param i32 i32 i32 i32) (result i32))'
    ]

    # (wrapper patched by the step, step), applied in this order.
    patchSteps = (
        ("$__wasi_fd_read", addResponseReadCallbackCode),
        ("$__wasi_path_open", addResponseOpenCallbackCode),
        ("$__wasi_fd_write", addFileWriteCallbackCode),
    )

    try:
        newWatContent = insertFunctions(originalWat, newFunctions)
        for wrapperName, patchStep in patchSteps:
            patched = patchStep(newWatContent)
            if patched == newWatContent:
                # The step's pattern did not match. The import was still added
                # but will never be called, so make that visible.
                logger.warning(
                    "%s not found in the expected form in %s; it was left unpatched.",
                    wrapperName, watPath,
                )
            newWatContent = patched
    except Exception as e:
        logger.exception(str(e))
        return False, None

    with open(newWatPath, 'w', encoding="utf8") as file:
        file.write(newWatContent)

    logger.debug("New WAT file created successfully.")
    return True, newWatPath