百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术文章 > 正文

comfyui自定义节点,生成自己想要的场景

itomcoil 2024-12-23 11:07 19 浏览

我的需求:

我想通过comfyui的节点工作流生成方式,实现一键视频生成文章的功能,这样我就能把自己喜欢的一些视频通过这种方式直接转化成PDF的形式。

实现过程

  1. 第一步:从bilibili网站找到直接喜欢的视频,通过视频链接下载到本地,生成对应的图片。
  2. 第二步:视频里提取音频,通过调用大模型生成文本。
  3. 第三步:通过图片结合文本方式形成PDF

实现技术

comfyui 自定义插件实现功能。

第一步已实现插件代码逻辑

在custom_nodes目录下创建自己的插件 ComfyUI-videoToArticle,如图所示:


进入插件目录,目录及文件如图:


实现第一步的三个节点源码,以下给大家分享。

__init__.py 源码:

import os
import subprocess
import sys
import importlib.util

# 检查并安装依赖的函数
def check_and_install_requirements():
    requirements_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")
    
    if not os.path.exists(requirements_path):
        print("未找到 requirements.txt 文件")
        return
    
    # 读取 requirements.txt
    with open(requirements_path, 'r', encoding='utf-8') as f:
        requirements = [line.strip() for line in f.readlines() if line.strip()]
    
    # 检查每个依赖
    for requirement in requirements:
        package_name = requirement.split('>=')[0].split('==')[0].strip()
        try:
            importlib.util.find_spec(package_name)
        except ImportError:
            print(f"正在安装依赖: {requirement}")
            try:
                subprocess.check_call([sys.executable, "-m", "pip", "install", requirement])
                print(f"成功安装: {requirement}")
            except subprocess.CalledProcessError as e:
                print(f"安装失败 {requirement}: {str(e)}")

# 在导入时自动检查并安装依赖
check_and_install_requirements()

# 导入节点类
try:
    from .视频获取 import VideoDownloader
    from .视频帧提取 import VideoFrameExtractor
    
    # 注册节点
    NODE_CLASS_MAPPINGS = {
        "VideoDownloader": VideoDownloader,
        "VideoFrameExtractor": VideoFrameExtractor
    }

    NODE_DISPLAY_NAME_MAPPINGS = {
        "VideoDownloader": "B站视频下载器",
        "VideoFrameExtractor": "视频帧提取器"
    }

except ImportError as e:
    print(f"导入节点类时出错: {str(e)}")
    NODE_CLASS_MAPPINGS = {}
    NODE_DISPLAY_NAME_MAPPINGS = {}

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']

视频获取.py 源码:

import os
import torch
import cv2
from bilibili_api import video, Credential, sync
import aiohttp
import asyncio
import re
import time
import subprocess

class VideoDownloader:
    def __init__(self):
        # 创建下载目录
        self.output_dir = "downloaded_videos"
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
            
        # B站请求头
        self.headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Referer': 'https://www.bilibili.com',
            'Accept': '*/*',
            'Origin': 'https://www.bilibili.com',
            'Accept-Encoding': 'gzip, deflate, br',
            'Accept-Language': 'zh-CN,zh;q=0.9',
        }

    CATEGORY = "视频转文章"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "视频链接": ("STRING", {
                    "default": "", 
                    "multiline": False,
                    "placeholder": "请输入B站视频URL或BV号"
                }),
                "预览帧": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 10000,
                    "step": 1,
                    "display": "number"
                }),
            },
            "optional": {
                "SESSDATA": ("STRING", {
                    "default": "",
                    "multiline": False,
                    "placeholder": "输入B站SESSDATA(可选)"
                })
            }
        }

    RETURN_TYPES = ("STRING", "IMAGE")
    RETURN_NAMES = ("视频路径", "预览图像")
    
    FUNCTION = "download_and_preview"

    def extract_bvid(self, url):
        # 从URL中提取BV号
        bv_pattern = r'BV[a-zA-Z0-9]+'
        match = re.search(bv_pattern, url)
        if match:
            return match.group()
        return url

    async def download_bilibili_video(self, url, sessdata=None):
        bvid = self.extract_bvid(url)
        temp_video = None
        temp_audio = None
        
        try:
            # 检查 ffmpeg 是否可用
            try:
                subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True)
            except (subprocess.SubprocessError, FileNotFoundError):
                raise ValueError("未找到 ffmpeg,请先安装 ffmpeg 并确保其在系统路径中")

            credential = None
            if sessdata:
                credential = Credential(sessdata=sessdata)
            v = video.Video(bvid=bvid, credential=credential)
            
            video_info = await v.get_info()
            title = video_info['title']
            
            # 安全的文件名
            safe_title = "".join(x for x in title if x.isalnum() or x in (' ','-','_')).rstrip()
            video_path = os.path.join(self.output_dir, f"{safe_title}.mp4")
            temp_video = os.path.join(self.output_dir, f"{safe_title}_temp_video.m4s")
            temp_audio = os.path.join(self.output_dir, f"{safe_title}_temp_audio.m4s")
            
            # 如果文件已存在,先删除
            for file in [video_path, temp_video, temp_audio]:
                if os.path.exists(file):
                    os.remove(file)
            
            video_url = await v.get_download_url(0)
            video_stream_url = video_url['dash']['video'][0]['baseUrl']
            audio_stream_url = video_url['dash']['audio'][0]['baseUrl']
            
            print(f"开始下载视频: {safe_title}")
            
            # 下载视频流
            async with aiohttp.ClientSession() as session:
                # 下载视频部分
                print("下载视频流...")
                async with session.get(video_stream_url, headers=self.headers) as resp:
                    if resp.status != 200:
                        raise ValueError(f"视频下载失败,状态码:{resp.status}")
                    with open(temp_video, 'wb') as f:
                        async for chunk in resp.content.iter_chunked(1024*1024):
                            f.write(chunk)
                
                # 下载音频部分
                print("下载音频流...")
                async with session.get(audio_stream_url, headers=self.headers) as resp:
                    if resp.status != 200:
                        raise ValueError(f"音频下载失败,状态码:{resp.status}")
                    with open(temp_audio, 'wb') as f:
                        async for chunk in resp.content.iter_chunked(1024*1024):
                            f.write(chunk)
            
            # 检查临时文件是否存在
            if not os.path.exists(temp_video) or not os.path.exists(temp_audio):
                raise ValueError("临时文件下载失败")
                
            print("合并音视频...")
            # 使用绝对路径执行ffmpeg
            ffmpeg_cmd = [
                'ffmpeg',
                '-i', os.path.abspath(temp_video),
                '-i', os.path.abspath(temp_audio),
                '-c', 'copy',
                os.path.abspath(video_path),
                '-y'
            ]
            
            process = subprocess.Popen(
                ffmpeg_cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True
            )
            stdout, stderr = process.communicate()
            
            if process.returncode != 0:
                print(f"FFmpeg错误输出: {stderr}")
                raise ValueError(f"FFmpeg合并失败,返回码: {process.returncode}")
            
            # 检查输出文件
            if not os.path.exists(video_path):
                raise ValueError("合并后的视频文件未生成")
                
            print("清理临时文件...")
            # 清理临时文件
            for file in [temp_video, temp_audio]:
                if os.path.exists(file):
                    os.remove(file)
            
            # 等待文件写入完成
            time.sleep(1)
            
            print("验证视频文件...")
            # 验证文件是否可读
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                raise ValueError("无法打开合并后的视频文件")
            cap.release()
            
            print("视频处理完成")
            return video_path
            
        except Exception as e:
            # 清理所有临时文件
            if temp_video and os.path.exists(temp_video):
                os.remove(temp_video)
            if temp_audio and os.path.exists(temp_audio):
                os.remove(temp_audio)
            raise ValueError(f"视频处理失败: {str(e)}")

    def download_and_preview(self, 视频链接, 预览帧, SESSDATA=""):
        if not 视频链接:
            raise ValueError("请输入有效的视频URL")

        try:
            video_path = asyncio.run(self.download_bilibili_video(视频链接, SESSDATA))
            
            # 等待文件完全写入
            time.sleep(1)
            
            # 尝试多次打开视频文件
            max_attempts = 3
            for attempt in range(max_attempts):
                cap = cv2.VideoCapture(video_path)
                if cap.isOpened():
                    break
                time.sleep(1)
            
            if not cap.isOpened():
                raise ValueError("无法打开视频文件")
            
            # 获取实际帧数
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if 预览帧 >= total_frames:
                预览帧 = 0
            
            cap.set(cv2.CAP_PROP_POS_FRAMES, 预览帧)
            ret, frame = cap.read()
            
            if not ret:
                raise ValueError(f"无法读取视频帧 {预览帧}")

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            preview_image = torch.from_numpy(frame_rgb).float() / 255.0
            preview_image = preview_image.unsqueeze(0)
            
            cap.release()

            return (video_path, preview_image)

        except Exception as e:
            raise ValueError(f"下载或处理视频时出错: {str(e)}")

    @classmethod
    def IS_CHANGED(cls, 视频链接, 预览帧, SESSDATA=""):
        return float("nan")

# 节点注册
NODE_CLASS_MAPPINGS = {
    "VideoDownloader": VideoDownloader
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "VideoDownloader": "B站视频下载器"
}

视频帧提取.py 源码

# -*- coding: utf-8 -*-
import os
import cv2
import torch
import numpy as np
from PIL import Image
import sys

class VideoFrameExtractor:
    def __init__(self):
        # 创建输出目录
        self.base_output_dir = "extracted_frames"
        if not os.path.exists(self.base_output_dir):
            os.makedirs(self.base_output_dir)

    CATEGORY = "视频转文章"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "视频路径": ("STRING", {"forceInput": True}),
                "起始帧": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 100000,
                    "step": 1,
                    "display": "number"
                }),
                "提取间隔": ("INT", {
                    "default": 30,
                    "min": 1,
                    "max": 300,
                    "step": 1,
                    "display": "number"
                }),
                "最大提取数": ("INT", {
                    "default": 20,
                    "min": 1,
                    "max": 100,
                    "step": 1,
                    "display": "number"
                }),
            },
            "optional": {
                "结束帧": ("INT", {
                    "default": -1,
                    "min": -1,
                    "max": 100000,
                    "step": 1,
                    "display": "number"
                }),
                "保存帧": ("BOOLEAN", {"default": True}),
                "子目录名": ("STRING", {
                    "default": "",
                    "multiline": False,
                    "placeholder": "可选,留空则使用视频文件名"
                }),
            }
        }

    RETURN_TYPES = ("IMAGE", "STRING")
    RETURN_NAMES = ("帧序列", "帧路径列表")
    
    FUNCTION = "extract_frames"

    def create_output_dir(self, video_path, sub_dir=""):
        try:
            if not sub_dir:
                sub_dir = os.path.splitext(os.path.basename(video_path))[0]
            output_dir = os.path.join(self.base_output_dir, sub_dir)
            original_dir = output_dir
            counter = 1
            while os.path.exists(output_dir):
                output_dir = f"{original_dir}_{counter}"
                counter += 1
            os.makedirs(output_dir)
            return output_dir
        except Exception as e:
            print(f"创建目录时出错: {str(e)}")
            import time
            backup_dir = os.path.join(self.base_output_dir, f"frames_{int(time.time())}")
            os.makedirs(backup_dir, exist_ok=True)
            return backup_dir

    def extract_frames(self, 视频路径, 起始帧, 提取间隔, 最大提取数, 结束帧=-1, 保存帧=True, 子目录名=""):
        if not os.path.exists(视频路径):
            raise ValueError(f"视频文件不存在: {视频路径}")

        cap = None
        try:
            output_dir = self.create_output_dir(视频路径, 子目录名) if 保存帧 else None
            
            cap = cv2.VideoCapture(视频路径)
            if not cap.isOpened():
                raise ValueError("无法打开视频文件")
            
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            
            if 结束帧 == -1 or 结束帧 >= total_frames:
                结束帧 = total_frames - 1
            
            if 起始帧 < 0 or 起始帧 >= total_frames:
                raise ValueError(f"起始帧超出范围 (0-{total_frames-1})")
            if 结束帧 < 起始帧:
                raise ValueError(f"结束帧必须大于起始帧")
            
            帧范围 = 结束帧 - 起始帧 + 1
            实际间隔 = max(提取间隔, int(帧范围 / 最大提取数))
            帧位置列表 = range(起始帧, 结束帧 + 1, 实际间隔)
            帧位置列表 = list(帧位置列表)[:最大提取数]
            
            frames = []
            frame_paths = []
            
            print(f"开始提取帧,范围:{起始帧}-{结束帧},间隔:{实际间隔},计划提取:{len(帧位置列表)}帧")
            
            for frame_pos in 帧位置列表:
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
                ret, frame = cap.read()
                
                if not ret:
                    print(f"警告:无法读取帧位置 {frame_pos}")
                    continue
                
                # 转换颜色空间并确保格式正确
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame_tensor = torch.from_numpy(frame_rgb).float() / 255.0
                
                # 确保维度正确 (H, W, C)
                if len(frame_tensor.shape) == 3:
                    frames.append(frame_tensor)
                    
                    if 保存帧 and output_dir:
                        frame_path = os.path.join(output_dir, f"frame_{frame_pos:06d}.png")
                        Image.fromarray((frame_rgb).astype(np.uint8)).save(frame_path)
                        frame_paths.append(frame_path)
                        print(f"已保存帧 {frame_pos}: {frame_path}")
            
            if not frames:
                raise ValueError("没有成功提取到帧")
            
            # 堆叠所有帧并确保格式正确 (N, H, W, C)
            frames_tensor = torch.stack(frames)
            frame_paths_str = ",".join(frame_paths) if frame_paths else ""
            
            print(f"帧提取完成,共提取{len(frames)}帧")
            if 保存帧:
                print(f"帧已保存到目录: {output_dir}")
            
            return (frames_tensor, frame_paths_str)

        except Exception as e:
            raise ValueError(f"提取帧时出错: {str(e)}")
        finally:
            if cap is not None:
                cap.release()

    @classmethod
    def IS_CHANGED(cls, 视频路径, *args):
        return float("nan")

# 节点注册
NODE_CLASS_MAPPINGS = {
    "VideoFrameExtractor": VideoFrameExtractor
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "VideoFrameExtractor": "视频帧提取器"
} 

注意:python文件的依赖要下载。

第一步实现的效果

相关推荐

Python Qt GUI设计:将UI文件转换Python文件三种妙招(基础篇—2)

在开始本文之前提醒各位朋友,Python记得安装PyQt5库文件,Python语言功能很强,但是Python自带的GUI开发库Tkinter功能很弱,难以开发出专业的GUI。好在Python语言的开放...

Connect 2.0来了,还有Nuke和Maya新集成

ftrackConnect2.0现在可以下载了--重新设计的桌面应用程序,使用户能够将ftrackStudio与创意应用程序集成,发布资产等。这个新版本的发布中还有两个Nuke和Maya新集成,...

Magicgui:不会GUI编程也能轻松构建Python GUI应用

什么是MagicguiMagicgui是一个Python库,它允许开发者仅凭简单的类型注解就能快速构建图形用户界面(GUI)应用程序。这个库基于Napari项目,利用了Python的强大类型系统,使得...

Python入坑系列:桌面GUI开发之Pyside6

阅读本章之后,你可以掌握这些内容:Pyside6的SignalsandSlots、Envents的作用,如何使用?PySide6的Window、DialogsandAlerts、Widgets...

Python入坑系列-一起认识Pyside6 designer可拖拽桌面GUI

通过本文章,你可以了解一下内容:如何安装和使用Pyside6designerdesigner有哪些的特性通过designer如何转成python代码以前以为Pyside6designer需要在下载...

pyside2的基础界面(pyside2显示图片)

今天我们来学习pyside2的基础界面没有安装过pyside2的小伙伴可以看主页代码效果...

Python GUI开发:打包PySide2应用(python 打包pyc)

之前的文章我们介绍了怎么使用PySide2来开发一个简单PythonGUI应用。这次我们来将上次完成的代码打包。我们使用pyinstaller。注意,pyinstaller默认会将所有安装的pack...

使用PySide2做窗体,到底是怎么个事?看这个能不能搞懂

PySide2是Qt框架的Python绑定,允许你使用Python创建功能强大的跨平台GUI应用程序。PySide2的基本使用方法:安装PySide2pipinstallPy...

pycharm中conda解释器无法配置(pycharm安装的解释器不能用)

之前用的好好的pycharm正常配置解释器突然不能用了?可以显示有这个环境然后确认后可以conda正在配置解释器,但是进度条结束后还是不成功!!试过了pycharm重启,pycharm重装,anaco...

Conda使用指南:从基础操作到Llama-Factory大模型微调环境搭建

Conda虚拟环境在Linux下的全面使用指南:从基础操作到Llama-Factory大模型微调环境搭建在当今的AI开发与数据分析领域,conda虚拟环境已成为Linux系统下管理项目依赖的标配工具。...

Python操作系统资源管理与监控(python调用资源管理器)

在现代计算环境中,对操作系统资源的有效管理和监控是确保应用程序性能和系统稳定性的关键。Python凭借其丰富的标准库和第三方扩展,提供了强大的工具来实现这一目标。本文将探讨Python在操作系统资源管...

本地部署开源版Manus+DeepSeek创建自己的AI智能体

1、下载安装Anaconda,设置conda环境变量,并使用conda创建python3.12虚拟环境。2、从OpenManus仓库下载代码,并安装需要的依赖。3、使用Ollama加载本地DeepSe...

一文教会你,搭建AI模型训练与微调环境,包学会的!

一、硬件要求显卡配置:需要Nvidia显卡,至少配备8G显存,且专用显存与共享显存之和需大于20G。二、环境搭建步骤1.设置文件存储路径非系统盘存储:建议将非安装版的环境文件均存放在非系统盘(如E盘...

使用scikit-learn为PyTorch 模型进行超参数网格搜索

scikit-learn是Python中最好的机器学习库,而PyTorch又为我们构建模型提供了方便的操作,能否将它们的优点整合起来呢?在本文中,我们将介绍如何使用scikit-learn中的网格搜...

如何Keras自动编码器给极端罕见事件分类

全文共7940字,预计学习时长30分钟或更长本文将以一家造纸厂的生产为例,介绍如何使用自动编码器构建罕见事件分类器。现实生活中罕见事件的数据集:背景1.什么是极端罕见事件?在罕见事件问题中,数据集是...