"""企业微信消息加解密
基于企业微信官方SDK的Python实现
"""
import base64
import hmac
import random
import secrets
import string
import struct
from hashlib import sha1

from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad


class WeChatCrypto:
    """Encrypt, decrypt and sign WeChat Work (企业微信) callback messages.

    Plaintext layout before padding::

        random(16) + msg_len(4, big-endian, UTF-8 byte count) + msg + corp_id

    The plaintext is PKCS#7-padded to a 32-byte boundary (as in the official
    SDK) and encrypted with AES in CBC mode, using the first 16 bytes of the
    AES key as the IV.
    """

    # WeChat pads to 32-byte blocks, not to AES's 16-byte block size.
    _PKCS7_BLOCK_SIZE = 32
    _RANDOM_PREFIX_LEN = 16
    _RANDOM_ALPHABET = string.ascii_letters + string.digits

    def __init__(self, token: str, encoding_aes_key: str, corp_id: str):
        """Store the callback credentials.

        Args:
            token: Token configured for the callback URL; used in signatures.
            encoding_aes_key: Base64 key text without its trailing '=' (the
                WeChat EncodingAESKey is 43 characters); one '=' is appended
                to restore padding before decoding.
            corp_id: Receiver id appended to every plaintext that is encrypted.
        """
        self.token = token
        self.corp_id = corp_id
        self.aes_key = base64.b64decode(encoding_aes_key + '=')

    def verify_signature(self, signature: str, timestamp: str, nonce: str, data: str) -> bool:
        """Return True if *signature* equals the expected msg_signature.

        The comparison is constant-time, so response timing does not reveal
        how much of a forged signature was correct. A non-str *signature*
        (e.g. a missing query parameter) is simply rejected.
        """
        if not isinstance(signature, str):
            return False
        expected = self.generate_signature(timestamp, nonce, data)
        # Compare bytes: compare_digest rejects non-ASCII str with TypeError.
        return hmac.compare_digest(expected.encode(), signature.encode('utf-8', 'surrogatepass'))

    def decrypt(self, encrypted_data: str) -> str:
        """Decrypt a Base64 ciphertext and return the embedded message text.

        Raises:
            ValueError: on malformed Base64 (binascii.Error is a subclass), a
                ciphertext that is not a whole number of AES blocks, bad
                padding, an inconsistent length prefix, or invalid UTF-8.
        """
        encrypted_bytes = base64.b64decode(encrypted_data)
        cipher = AES.new(self.aes_key, AES.MODE_CBC, self.aes_key[:16])
        plaintext = self._pkcs7_unpad(cipher.decrypt(encrypted_bytes))
        return self._unpack_plaintext(plaintext)

    def encrypt(self, message: str) -> str:
        """Encrypt *message* and return the Base64-encoded ciphertext."""
        padded = self._pkcs7_pad(self._pack_plaintext(message))
        cipher = AES.new(self.aes_key, AES.MODE_CBC, self.aes_key[:16])
        return base64.b64encode(cipher.encrypt(padded)).decode('utf-8')

    def generate_signature(self, timestamp: str, nonce: str, data: str) -> str:
        """Return the SHA-1 hex digest of the sorted, concatenated inputs.

        The token, timestamp, nonce and data strings are sorted
        lexicographically and joined before hashing.
        """
        tmp_str = ''.join(sorted([self.token, timestamp, nonce, data]))
        return sha1(tmp_str.encode()).hexdigest()

    def _pack_plaintext(self, message: str) -> bytes:
        """Build random(16) + msg_len(4) + msg + corp_id for *message*."""
        # The fixed IV makes this prefix the only per-message randomness, so
        # draw it from a CSPRNG.
        random_bytes = ''.join(
            secrets.choice(self._RANDOM_ALPHABET) for _ in range(self._RANDOM_PREFIX_LEN)
        ).encode()
        msg_bytes = message.encode('utf-8')
        # The prefix counts encoded bytes; len(message) would count characters
        # and corrupt any non-ASCII (e.g. Chinese) message.
        msg_len = struct.pack('!I', len(msg_bytes))
        return b''.join((random_bytes, msg_len, msg_bytes, self.corp_id.encode('utf-8')))

    def _unpack_plaintext(self, plaintext: bytes) -> str:
        """Extract the UTF-8 message from an unpadded plaintext.

        Raises:
            ValueError: if the buffer is too short for its length prefix.
        """
        content = plaintext[self._RANDOM_PREFIX_LEN:]
        if len(content) < 4:
            raise ValueError("Decrypted message is too short")
        (msg_len,) = struct.unpack('!I', content[:4])
        msg_end = 4 + msg_len
        if msg_end > len(content):
            raise ValueError("Declared message length exceeds decrypted data")
        # NOTE(review): content[msg_end:] is the receiver id; the official SDK
        # rejects messages whose id differs from the configured one. That
        # check is not performed here.
        return content[4:msg_end].decode('utf-8')

    @classmethod
    def _pkcs7_pad(cls, data: bytes) -> bytes:
        """PKCS#7-pad *data* to a multiple of the 32-byte WeChat block size."""
        amount = cls._PKCS7_BLOCK_SIZE - len(data) % cls._PKCS7_BLOCK_SIZE
        return data + bytes([amount]) * amount

    @classmethod
    def _pkcs7_unpad(cls, data: bytes) -> bytes:
        """Strip PKCS#7 padding of 1..32 bytes.

        This also accepts 16-byte-block padding.

        Raises:
            ValueError: if *data* is empty or the padding bytes are inconsistent.
        """
        if not data:
            raise ValueError("Zero-length input cannot be unpadded")
        amount = data[-1]
        if not 1 <= amount <= min(cls._PKCS7_BLOCK_SIZE, len(data)):
            raise ValueError("PKCS#7 padding is incorrect")
        if data[-amount:] != bytes([amount]) * amount:
            raise ValueError("PKCS#7 padding is incorrect")
        return data[:-amount]
