import pikepdf
from pikepdf import Pdf, Rectangle, Stream, Array, Name
import os
import sys
from pathlib import Path
from natsort import natsorted

class PDFMerger:
    """PDF批量合并工具 - 统一为A4尺寸"""
    
    # A4纸张尺寸（单位：points）
    A4_WIDTH = 595.28
    A4_HEIGHT = 841.89
    
    def __init__(self, output_path="merged_output.pdf"):
        self.output_path = output_path
        self.merged_pdf = Pdf.new()
        self.total_pages = 0
        self.processed_files = 0
        
    def normalize_page_to_a4(self, page):
        """
        将单个页面标准化为A4尺寸
        """
        try:
            # 获取原始页面尺寸
            if '/MediaBox' in page:
                mediabox = page.MediaBox
                original_width = float(mediabox[2] - mediabox[0])
                original_height = float(mediabox[3] - mediabox[1])
            else:
                original_width = self.A4_WIDTH
                original_height = self.A4_HEIGHT
            
            # 如果页面已经是A4尺寸，跳过处理
            if abs(original_width - self.A4_WIDTH) < 1 and abs(original_height - self.A4_HEIGHT) < 1:
                return True
            
            # 计算缩放比例
            scale_x = self.A4_WIDTH / original_width
            scale_y = self.A4_HEIGHT / original_height
            
            # 使用较小的缩放比例，确保内容完整显示
            scale = min(scale_x, scale_y)
            
            # 计算居中偏移
            scaled_width = original_width * scale
            scaled_height = original_height * scale
            x_offset = (self.A4_WIDTH - scaled_width) / 2
            y_offset = (self.A4_HEIGHT - scaled_height) / 2
            
            # 处理 Contents
            if '/Contents' in page:
                # 获取所有内容流
                contents = page.Contents
                
                # 如果 Contents 是数组，需要合并
                if isinstance(contents, Array):
                    # 读取所有内容流并合并
                    combined_content = b''
                    for content_obj in contents:
                        try:
                            if hasattr(content_obj, 'read_bytes'):
                                combined_content += content_obj.read_bytes() + b'\n'
                        except:
                            pass
                    
                    # 创建新的内容流
                    if combined_content:
                        transform = f"q {scale} 0 0 {scale} {x_offset} {y_offset} cm\n"
                        new_content = transform.encode() + combined_content + b"Q\n"
                        
                        # 创建新的Stream对象
                        new_stream = Stream(self.merged_pdf, new_content)
                        page.Contents = new_stream
                
                # 如果 Contents 是单个流
                else:
                    try:
                        original_content = contents.read_bytes()
                        transform = f"q {scale} 0 0 {scale} {x_offset} {y_offset} cm\n"
                        new_content = transform.encode() + original_content + b"Q\n"
                        
                        # 写入新内容
                        new_stream = Stream(self.merged_pdf, new_content)
                        page.Contents = new_stream
                    except Exception as e:
                        print(f"    警告：内容流处理失败 - {str(e)}")
            
            # 设置新的MediaBox为A4尺寸
            page.MediaBox = Rectangle(0, 0, self.A4_WIDTH, self.A4_HEIGHT)
            
            # 同时设置 CropBox, BleedBox, TrimBox, ArtBox（如果存在）
            if '/CropBox' in page:
                page.CropBox = Rectangle(0, 0, self.A4_WIDTH, self.A4_HEIGHT)
            if '/BleedBox' in page:
                page.BleedBox = Rectangle(0, 0, self.A4_WIDTH, self.A4_HEIGHT)
            if '/TrimBox' in page:
                page.TrimBox = Rectangle(0, 0, self.A4_WIDTH, self.A4_HEIGHT)
            if '/ArtBox' in page:
                page.ArtBox = Rectangle(0, 0, self.A4_WIDTH, self.A4_HEIGHT)
            
            return True
            
        except Exception as e:
            print(f"    警告：页面标准化失败 - {str(e)}")
            # 即使失败也设置MediaBox，确保页面大小统一
            try:
                page.MediaBox = Rectangle(0, 0, self.A4_WIDTH, self.A4_HEIGHT)
            except:
                pass
            return False
    
    def add_pdf_file(self, pdf_path):
        """
        添加一个PDF文件到合并列表
        """
        try:
            print(f"\n处理文件: {os.path.basename(pdf_path)}")
            
            # 打开PDF文件
            src_pdf = Pdf.open(pdf_path, allow_overwriting_input=True)
            page_count = len(src_pdf.pages)
            
            print(f"  页数: {page_count}")
            
            success_count = 0
            
            # 处理每一页
            for i, page in enumerate(src_pdf.pages):
                try:
                    # 先添加页面到合并PDF
                    self.merged_pdf.pages.append(page)
                    
                    # 获取刚添加的页面并标准化为A4尺寸
                    new_page = self.merged_pdf.pages[-1]
                    if self.normalize_page_to_a4(new_page):
                        success_count += 1
                    
                    # 显示进度
                    if (i + 1) % 20 == 0 or (i + 1) == page_count:
                        print(f"  进度: {i + 1}/{page_count} 页 (成功标准化: {success_count})")
                
                except Exception as e:
                    print(f"  警告：第 {i+1} 页处理失败 - {str(e)}")
                    continue
            
            src_pdf.close()
            
            self.total_pages += page_count
            self.processed_files += 1
            print(f"  ✓ 完成，已添加 {page_count} 页 (成功标准化: {success_count}/{page_count})")
            
            return True
            
        except Exception as e:
            print(f"  ✗ 错误: {str(e)}")
            return False
    
    def merge_directory(self, directory_path, pattern="*.pdf", recursive=False, exclude_output=True):
        """
        合并目录中的所有PDF文件
        
        参数:
            directory_path: 目录路径
            pattern: 文件匹配模式（默认 *.pdf）
            recursive: 是否递归搜索子目录
            exclude_output: 是否排除输出文件本身
        """
        print("=" * 70)
        print("PDF批量合并工具 - 统一为A4尺寸")
        print("=" * 70)
        
        # 获取所有PDF文件
        path = Path(directory_path)
        
        if recursive:
            pdf_files = natsorted(path.rglob(pattern))
        else:
            pdf_files = natsorted(path.glob(pattern))
        
        # 排除输出文件本身
        if exclude_output:
            output_name = Path(self.output_path).name
            pdf_files = [f for f in pdf_files if f.name != output_name]
        
        if not pdf_files:
            print(f"\n✗ 在目录 {directory_path} 中未找到PDF文件")
            return False
        
        print(f"\n找到 {len(pdf_files)} 个PDF文件:")
        for i, pdf_file in enumerate(pdf_files, 1):
            file_size = pdf_file.stat().st_size / 1024  # KB
            print(f"  {i}. {pdf_file.name} ({file_size:.1f} KB)")
        
        # 依次处理每个文件
        print(f"\n开始合并处理...")
        
        for pdf_file in pdf_files:
            self.add_pdf_file(str(pdf_file))
        
        # 保存结果
        return self.save()
    
    def merge_file_list(self, file_list):
        """
        合并指定的PDF文件列表
        
        参数:
            file_list: PDF文件路径列表
        """
        print("=" * 70)
        print("PDF批量合并工具 - 统一为A4尺寸")
        print("=" * 70)
        
        print(f"\n准备合并 {len(file_list)} 个文件:")
        for i, pdf_file in enumerate(file_list, 1):
            print(f"  {i}. {os.path.basename(pdf_file)}")
        
        print(f"\n开始合并处理...")
        
        for pdf_file in file_list:
            if os.path.exists(pdf_file):
                self.add_pdf_file(pdf_file)
            else:
                print(f"\n✗ 文件不存在: {pdf_file}")
        
        # 保存结果
        return self.save()
    
    def save(self):
        """
        保存合并后的PDF
        """
        try:
            if self.total_pages == 0:
                print("\n✗ 没有可保存的页面")
                return False
            
            print(f"\n{'=' * 70}")
            print("保存合并结果...")
            
            self.merged_pdf.save(self.output_path)
            self.merged_pdf.close()
            
            # 获取输出文件大小
            file_size = os.path.getsize(self.output_path) / (1024 * 1024)
            
            print(f"\n✓ 合并完成！")
            print(f"  处理文件数: {self.processed_files}")
            print(f"  总页数: {self.total_pages}")
            print(f"  输出文件: {self.output_path}")
            print(f"  文件大小: {file_size:.2f} MB")
            print(f"  页面尺寸: A4 ({self.A4_WIDTH} x {self.A4_HEIGHT} points)")
            print("=" * 70)
            
            return True
            
        except Exception as e:
            print(f"\n✗ 保存失败: {str(e)}")
            import traceback
            traceback.print_exc()
            return False


def main():
    """
    主函数 - 支持多种使用方式
    """
    print("\n")
    
    # ========== 使用方式 1: 合并指定目录中的所有PDF ==========
    # 使用方法：将所有要合并的PDF放在同一个文件夹中
    
    merger = PDFMerger(output_path="合并输出.pdf")
    
    # 选项A: 合并当前目录下的所有PDF（不包括子目录）
    merger.merge_directory(".", pattern="*.pdf", recursive=False)
    
    # 选项B: 合并指定目录（包括子目录）
    # merger.merge_directory("./pdf文件夹", pattern="*.pdf", recursive=True)
    
    
    # ========== 使用方式 2: 合并指定的文件列表 ==========
    # 适合需要按特定顺序合并的情况
    
    # merger = PDFMerger(output_path="合并输出.pdf")
    # file_list = [
    #     "文件1.pdf",
    #     "文件2.pdf",
    #     "文件3.pdf",
    #     "子目录/文件4.pdf"
    # ]
    # merger.merge_file_list(file_list)
    
    
    # ========== 使用方式 3: 命令行参数 ==========
    # python script.py 输入目录 输出文件.pdf
    
    # if len(sys.argv) >= 2:
    #     input_dir = sys.argv[1]
    #     output_file = sys.argv[2] if len(sys.argv) >= 3 else "merged_output.pdf"
    #     
    #     merger = PDFMerger(output_path=output_file)
    #     merger.merge_directory(input_dir, recursive=True)


if __name__ == "__main__":
    main()