#!/usr/bin/env python
# -*- coding: utf-8 -*-
# $Id: RunWASM.py 14896 2024-09-11 04:21:02Z Tim $
#
# Copyright (c) 2024 Nuwa Information Co., Ltd, All Rights Reserved.
#
# Licensed under the Proprietary License,
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at our web site.
#
# See the License for the specific language governing permissions and
# limitations under the License.
#
# $Author: Tim $
# $Date: 2024-09-11 12:21:02 +0800 (週三, 11 九月 2024) $
# $Revision: 14896 $

import json
import logging
import os
import sys
from typing import List, Optional

from wasmtime import Engine, Store, Module, Config, Linker, WasiConfig, FuncType, ValType
from wasmtime import ExitTrap, Trap, WasmtimeError

from .APIClient import FileAPIClient
from .DataTypes import GETConfig, POSTConfig, WasmResultStatus
from .utils import isUrlAllowed
from .utils.Handlers import FileLimitHandler
from .utils.Exceptions import WorkspaceFileCountLimitExceededException, WorkspaceSizeLimitExceededException

# Environment setup (conda): conda install -c conda-forge libpython-static
# wasmtime Python bindings (the `wasmtime` package imported above): https://github.com/bytecodealliance/wasmtime-py

# module-level logger, named after this module, for host-side warnings and errors
logger = logging.getLogger(name=__name__)

class _GuestMemoryError(Exception):
    """Raised when a guest-supplied pointer/length falls outside the guest's linear memory."""


def _readGuestBytes(memory, store, start: int, length: int) -> bytes:
    """
        copy `length` bytes of guest linear memory starting at offset `start`

        `memory.data_ptr(store)` is a raw ctypes pointer and slicing it is not
        bounds-checked, so the range is validated against `memory.data_len(store)`
        first. raises _GuestMemoryError if [start, start + length) is out of bounds.
    """
    end = start + length
    if start < 0 or length < 0 or end > memory.data_len(store):
        raise _GuestMemoryError(f"guest memory access out of bounds: [{start}, {end})")
    return bytes(memory.data_ptr(store)[start:end])


def _readGuestU32(memory, store, address: int) -> int:
    """read a little-endian unsigned 32-bit value (a wasm32 pointer, size or fd) from guest memory"""
    return int.from_bytes(_readGuestBytes(memory, store, address, 4), "little")


def _readGuestCString(memory, store, address: int) -> str:
    """
        read a NUL-terminated string from guest memory and decode it as UTF-8

        the scan for the terminator never goes past the end of linear memory;
        invalid UTF-8 sequences are replaced with U+FFFD.
        raises _GuestMemoryError if `address` is out of bounds or no NUL is found.
    """
    limit = memory.data_len(store)
    if address < 0 or address >= limit:
        raise _GuestMemoryError(f"guest string pointer out of bounds: {address}")

    buffer = memory.data_ptr(store)
    end = address
    while end < limit and buffer[end] != 0:
        end += 1
    if end == limit:
        raise _GuestMemoryError(f"unterminated guest string at {address}")
    return bytes(buffer[address:end]).decode("utf-8", errors="replace")


def _openDebugLogger(workspaceDir: str):
    """
        attach a new file handler writing to <workspaceDir>/logs/debug.log
        (truncated on open) to the shared 'wasm_debug' logger

        returns (logger, handler). The caller must remove and close the handler
        when the run ends; otherwise handlers pile up on the shared logger across
        runs and keep their files open.
    """
    wasmLogger = logging.getLogger('wasm_debug')
    wasmLogger.setLevel(logging.DEBUG)

    handler = logging.FileHandler(
        filename=os.path.join(workspaceDir, "logs", 'debug.log'),
        mode="w",
        encoding="utf8"
    )
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    wasmLogger.addHandler(handler)
    return wasmLogger, handler


def runWASM(
    workspaceDir: str,
    wasmFilePath: str,
    fuel: int = 500_000_000,
    memorySize: int = 32 * 1024 * 1024, # 32 MB (power of two); TODO: measure how much is actually enough
    fileSizeLimit: int = 1024 * 1024 * 10,
    fileCountLimit: int = 1000,
    allowedURLs: Optional[List[str]] = None,
    debug: bool = False
) -> dict:

    """
        run the specified wasm file using wasmtime

        workspaceDir: host directory preopened as the guest's "/" (abs path);
            guest stdout/stderr go to <workspaceDir>/logs/out.log and err.log
        wasmFilePath: where the wasm file is (abs path)
        fuel: fuel given to the store (fuel consumption is enabled)
        memorySize: memory limit passed to store.set_limits(memory_size=...)
        fileSizeLimit / fileCountLimit: limits given to FileLimitHandler, which
            checks data the guest passes to fileWriteCallback
        allowedURLs: URL patterns the guest may request via get.txt / post.txt;
            None (the default) means ["*"]
        debug: enable wasmtime debug info and write <workspaceDir>/logs/debug.log

        returns {"status": WasmResultStatus, "message": error text or None}.
        Errors raised while compiling or instantiating the module, and exceptions
        other than the ones caught around `_start`, propagate to the caller.
    """

    if allowedURLs is None:
        allowedURLs = ["*"]

    cfg = Config()
    # necessary if we want to interrupt execution after some amount of instructions executed
    cfg.consume_fuel = True
    cfg.cache = True
    cfg.debug_info = debug

    engine = Engine(cfg)
    module = Module.from_file(engine, wasmFilePath)

    # the linker provides WASI plus the host callbacks defined below
    linker = Linker(engine)
    linker.define_wasi()

    # the store holds the instance and its WASI state
    store = Store(engine)

    wasi = WasiConfig()
    wasi.inherit_stdin()
    wasi.argv = sys.argv  # NOTE(review): exposes the host process's argv to the guest
    wasi.preopen_dir(workspaceDir, "/")
    wasi.stdout_file = os.path.join(workspaceDir, "logs", 'out.log') # host path
    wasi.stderr_file = os.path.join(workspaceDir, "logs", 'err.log')

    store.set_fuel(fuel)
    store.set_wasi(wasi)

    # recommended setting: server memory (GB) / (number of wasm modules expected to run concurrently)
    # e.g. a 16 GB server with ten concurrent wasm modules => memory_size = 512 * 1024 * 1024  # 512 MB
    store.set_limits(memory_size=memorySize)

    workspaceRoot = os.path.realpath(workspaceDir)
    tempRecord = {
        "currentRequestConfigPath": None,
        "currentResponseFd": None
    }
    # replaced by the shared 'wasm_debug' logger when debug is on; only used under `if debug`
    wasmLogger = None

    def sendRequest(requestConfigPath: str) -> str:
        """
            perform the HTTP request described by the guest's get.txt / post.txt
            (JSON with "url" and optional request fields) and return the text
            to write into response.txt
        """
        requestFileName = os.path.basename(requestConfigPath)

        # the path comes from the guest; refuse anything resolving outside the workspace
        hostPath = os.path.realpath(os.path.join(workspaceRoot, requestConfigPath.lstrip("/")))
        if os.path.commonpath([workspaceRoot, hostPath]) != workspaceRoot:
            raise PermissionError(f"request config path escapes the workspace: {requestConfigPath}")

        with open(hostPath, "r", encoding="utf8") as f: # request.txt
            requestConfig = json.load(f)

        url = requestConfig["url"]
        if not isUrlAllowed(url, allowedURLs):
            responseText = json.dumps(
                {
                    "status_code": 400,
                    "message": "this request url is not allowed"
                }
            )
        else:
            if debug:
                wasmLogger.info(f"making {requestFileName.split('.')[0].upper()} request to {url}")

            if requestFileName == "get.txt":
                config = GETConfig(
                    url=url,
                    headers=requestConfig.get("headers"),
                    params=requestConfig.get("params"),
                    timeout=requestConfig.get("timeout")
                )
                responseText = FileAPIClient.sendGETRequest(config)
            elif requestFileName == "post.txt":
                config = POSTConfig(
                    url=url,
                    headers=requestConfig.get("headers"),
                    timeout=requestConfig.get("timeout"),
                    data=requestConfig.get("data"),
                    json=requestConfig.get("json"),
                )
                responseText = FileAPIClient.sendPostRequest(config)
            else:
                raise ValueError(f"unsupported request file: {requestFileName}")

        if debug:
            wasmLogger.info(f"response from {url}: {responseText}")
        return responseText

    # NOTE: called by the guest when it reads response.txt
    fileReadFuncType = FuncType([ValType.i32(), ValType.i32(), ValType.i32(), ValType.i32()], [ValType.i32()])
    def responseReadCallback(fd, iovs, iovs_len, nread_ptr):
        """
            fd: the guest file descriptor being read
            iovs / iovs_len: the guest's iovec array (unused here)
            nread_ptr: guest pointer to the u32 "bytes read" counter

            If `fd` is the recorded response.txt fd, a request config has been
            opened, and the counter is 0, perform the request and write the
            result to <workspaceDir>/requests/response.txt. Always returns `fd`.
        """
        currentRequestConfigPath = tempRecord["currentRequestConfigPath"]
        if fd != tempRecord["currentResponseFd"] or currentRequestConfigPath is None:
            return fd
        # presumably a zero counter means response.txt has not been read yet — confirm with the guest
        if _readGuestU32(memory, store, nread_ptr) != 0:
            return fd

        try:
            responseText = sendRequest(currentRequestConfigPath)
        except Exception as e:
            # best effort: a failed request is logged and response.txt is left untouched
            logger.warning(str(e), exc_info=True)
            return fd

        # write the response result back to response.txt
        with open(os.path.join(workspaceDir, "requests", "response.txt"), "w", encoding="utf8") as f:
            f.write(responseText)
        return fd

    # (param i32 i32 i32 i32 i64 i64 i32 i32) (result i32)
    fileOpenFuncType = FuncType([ValType.i32(), ValType.i32(), ValType.i32(), ValType.i32(), ValType.i64(), ValType.i64(), ValType.i32(), ValType.i32()], [ValType.i32()])

    # NOTE: called by the guest when it opens a file (response.txt / get.txt / post.txt)
    def responseOpenCallback(dirfd, dirflags, pathPointer, pathLen, fsRightsBase, fsRightsInheriting, fsFlags, fdPointer):
        """
            Record the fd the guest received for response.txt and the path of
            the request config (get.txt / post.txt) it opened, for use by
            responseReadCallback.

            Only the 3rd argument (guest pointer to a NUL-terminated path) and
            the 8th (guest pointer to the u32 fd that was opened) are used.
            NOTE(review): WASI path_open takes 9 arguments but this import has 8;
            the names of the unused arguments are a best guess — confirm against
            the guest. Always returns 0.
        """
        path = _readGuestCString(memory, store, pathPointer)
        fileName = os.path.basename(path)

        # record the current response file being opened
        if fileName == "response.txt":
            tempRecord["currentResponseFd"] = _readGuestU32(memory, store, fdPointer)
        # record the request file name being opened
        if fileName in ("get.txt", "post.txt"):
            tempRecord["currentRequestConfigPath"] = path

        return 0

    fileHandler = FileLimitHandler(directory=workspaceDir, maxSize=fileSizeLimit, maxFiles=fileCountLimit)
    fileWriteFuncType = FuncType([ValType.i32(), ValType.i32(), ValType.i32(), ValType.i32()], [ValType.i32()])

    def fileWriteCallback(fd, iovs, iovs_len, nwritten):
        """
            Gather the bytes described by the guest's iovec array (`iovs_len`
            entries of {u32 base, u32 len} at `iovs`), decode them as UTF-8,
            and pass the text to FileLimitHandler.check. That call presumably
            raises the Workspace*LimitExceededException caught around `_start`.
            `nwritten` is not written. Always returns 0.
        """
        payload = bytearray()
        for index in range(iovs_len):
            entry = iovs + index * 8
            base = _readGuestU32(memory, store, entry)
            length = _readGuestU32(memory, store, entry + 4)
            payload.extend(_readGuestBytes(memory, store, base, length))

        # errors="replace": a binary write must not abort the host with UnicodeDecodeError
        fileHandler.check(payload.decode("utf-8", errors="replace"))
        return 0

    linker.define_func("", "responseReadCallback", fileReadFuncType, responseReadCallback)
    linker.define_func("", "responseOpenCallback", fileOpenFuncType, responseOpenCallback)
    linker.define_func("", "fileWriteCallback", fileWriteFuncType, fileWriteCallback)

    debugHandler = None
    if debug:
        wasmLogger, debugHandler = _openDebugLogger(workspaceDir)

    linkingInst = exports = memory = None
    try:
        linkingInst = linker.instantiate(store, module)
        exports = linkingInst.exports(store)
        memory = exports["memory"]

        status = WasmResultStatus.SUCCESS
        errorMsg = None
        try:
            exports["_start"](store)
        except ExitTrap:
            # NOTE: any exit code, including a non-zero one, is reported as SUCCESS
            pass
        except (Trap, WasmtimeError, _GuestMemoryError, WorkspaceSizeLimitExceededException, WorkspaceFileCountLimitExceededException) as e:
            logger.exception(str(e))
            status = WasmResultStatus.ERROR
            errorMsg = str(e)
    finally:
        if debug:
            wasmLogger.info(f"{store.get_fuel()=}")
            wasmLogger.info(f"{fuel=}")
            if fuel:
                wasmLogger.info(f"{(store.get_fuel()/fuel)=}")
            # detach and close this run's handler so it does not leak into later runs
            wasmLogger.removeHandler(debugHandler)
            debugHandler.close()

        # drop references to the instance, its exports/memory and the store so
        # the wasm memory can be released
        linkingInst = exports = memory = None
        store = None

        # Trigger garbage collection to free memory
        import gc
        gc.collect()

    return {
        "status": status,
        "message": errorMsg
    }