"""
|
|
Main module that handles processing of YouTube transcripts and connecting to the AI service.
|
|
Each user session has its own output stream and thread to handle the asynchronous AI response.
|
|
"""
|
|
|
|
import re
|
|
import threading
|
|
import asyncio
|
|
from asyncio import sleep
|
|
from datetime import datetime
|
|
import pytz
|
|
import os
|
|
import logging
|
|
import uuid
|
|
|
|
# Youtube Transcript imports
|
|
import youtube_transcript_api._errors
|
|
from youtube_transcript_api import YouTubeTranscriptApi
|
|
from youtube_transcript_api.formatters import TextFormatter
|
|
|
|
# OpenAI API imports
|
|
from openai import AssistantEventHandler
|
|
from openai import OpenAI
|
|
|
|
from dotenv import load_dotenv
|
|
load_dotenv()
|
|
|
|
# Global dict for per-user session streams.
|
|
user_streams = {}
|
|
# Lock to ensure thread-safe operations on shared memory.
|
|
stream_lock = threading.Lock()
|
|
|
|
# For running async code in non-async functions.
|
|
awaiter = asyncio.run
|
|
|
|
# Configure logging
try:
    logging.basicConfig(
        filename='./logs/main.log',
        level=logging.INFO,
        format='%(asctime)s %(levelname)s: %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
except FileNotFoundError as e:
    # basicConfig raises FileNotFoundError when the ./logs directory is missing;
    # create the directory (the log file itself is created automatically) and retry.
    os.makedirs("./logs", exist_ok=True)
    logging.basicConfig(
        filename='./logs/main.log',
        level=logging.INFO,
        format='%(asctime)s %(levelname)s: %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    logging.info(f"No ./logs/main.log was found ({e}), so it was created.")

class StreamOutput:
    """
    Class to encapsulate a session's streaming output.

    Attributes:
        delta (str): Last delta update.
        response (str): Cumulative response from the AI.
        done (bool): Flag indicating if streaming is complete.
        buffer (list): List of output delta strings pending streaming.
    """
    def __init__(self):
        self.delta: str = ""
        self.response: str = ""
        self.done: bool = False
        self.buffer: list = []

    def reset(self):
        """
        Reset the stream output to its initial state.
        """
        self.delta = ""
        self.response = ""
        self.done = False
        self.buffer = []

    def send_delta(self, delta):
        """
        Process a new delta string. This method is a synchronous wrapper that calls the async
        method process_delta.

        Args:
            delta (str): The delta string to process.
        """
        awaiter(self.process_delta(delta))

    async def process_delta(self, delta):
        """
        Process a new delta chunk asynchronously to update buffering.

        Args:
            delta (str): The delta portion of the response.
        """
        self.delta = delta
        self.response += delta

        def get_index(lst):
            return 0 if not lst else len(lst) - 1

        if self.buffer:
            try:
                if self.delta != self.buffer[get_index(self.buffer)]:
                    self.buffer.append(delta)
            except IndexError as index_error:
                logging.error(f"Caught IndexError: {str(index_error)}")
                self.buffer.append(delta)
        else:
            self.buffer.append(delta)
        return
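
# Illustrative StreamOutput flow (a sketch only; the concrete wiring lives in
# EventHandler and yoink() below, and the names `out` and `chunk` are hypothetical):
#
#   out = StreamOutput()
#   out.send_delta("Hello, ")    # producer side: buffers the delta and grows .response
#   out.send_delta("world")
#   out.done = True              # producer marks completion
#   while out.buffer:
#       chunk = out.buffer.pop(0)    # consumer side: drains deltas in arrival order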

# OpenAI Client configuration
client = OpenAI(
    organization='org-7ANUFsqOVIXLLNju8Rvmxu3h',
    project="proj_NGz8Kux8CSka7DRJucAlDCz6",
    api_key=os.getenv("OPENAI_API_KEY")
)

asst_screw_bardo_id = "asst_JGFaX6uOIotqy5mIJnu3Yyp7"  # Assistant ID for processing

class EventHandler(AssistantEventHandler):
    """
    Event handler for processing OpenAI assistant events.

    Attributes:
        output_stream (StreamOutput): The output stream to write updates to.
    """
    def __init__(self, output_stream: StreamOutput):
        """
        Initialize the event handler with a specific output stream.

        Args:
            output_stream (StreamOutput): The session-specific stream output instance.
        """
        super().__init__()
        self.output_stream = output_stream

    def on_text_created(self, text) -> None:
        """
        Event triggered when text is first created.

        Args:
            text (Any): The initial response text object.
        """
        self.output_stream.send_delta("Response Received:\n\nScrew-Bardo:\n\n")
        logging.info("Text created event handled.")

    def on_text_delta(self, delta, snapshot):
        """
        Event triggered when a new text delta is available.

        Args:
            delta (Any): Object that contains the new delta information.
            snapshot (Any): A snapshot of the current output (if applicable).
        """
        self.output_stream.send_delta(delta.value)
        logging.debug(f"Text delta received: {delta.value}")

    def on_tool_call_created(self, tool_call):
        """
        Handle the case when the assistant attempts to call a tool.
        Raises an exception as this behavior is unexpected.

        Args:
            tool_call (Any): The tool call info.

        Raises:
            Exception: Always, since tool calls are not allowed.
        """
        error_msg = "Assistant shouldn't be calling tools."
        logging.error(error_msg)
        raise Exception(error_msg)

def create_and_stream(transcript, session_id):
    """
    Run the OpenAI stream for a given session's transcript. Intended to be
    executed in the session's dedicated worker thread.

    Args:
        transcript (str): The transcript from the YouTube video.
        session_id (str): The unique session identifier.
    """
    logging.info(f"Starting OpenAI stream thread for session {session_id}.")
    event_handler = EventHandler(user_streams[session_id]['output_stream'])
    try:
        with client.beta.threads.create_and_run_stream(
            assistant_id=asst_screw_bardo_id,
            thread={
                "messages": [{"role": "user", "content": transcript}]
            },
            event_handler=event_handler
        ) as stream:
            stream.until_done()
        with stream_lock:
            user_streams[session_id]['output_stream'].done = True
        logging.info(f"OpenAI stream completed for session {session_id}.")
    except Exception:
        logging.exception(f"Exception occurred during create_and_stream for session {session_id}.")
        # Mark the stream as done so the consumer generator does not wait forever.
        with stream_lock:
            user_streams[session_id]['output_stream'].done = True

def yoink(session_id):
    """
    Generator that yields streaming output for a session.

    This function starts the AI response thread, then continuously yields data from the
    session's output buffer until the response is marked as done.

    Args:
        session_id (str): The unique session identifier.

    Yields:
        bytes: Chunks of the AI-generated response.
    """
    logging.info(f"Starting stream for session {session_id}...")
    with stream_lock:
        user_data = user_streams.get(session_id)
        if not user_data:
            logging.critical(f"User data not found for session id {session_id}?")
            return
        output_stream: StreamOutput = user_data.get('output_stream')
        thread: threading.Thread = user_data.get('thread')
        thread.start()
    while True:
        if not output_stream or not thread:
            logging.error(f"No output stream/thread for session {session_id}.")
            break
        # Stop streaming when done and there is no pending buffered output.
        if output_stream.done and not output_stream.buffer:
            break
        try:
            if output_stream.buffer:
                delta = output_stream.buffer.pop(0)
                yield bytes(delta, encoding="utf-8")
            else:
                # A short sleep before looping again
                asyncio.run(sleep(0.018))
        except Exception as e:
            logging.exception(f"Exception occurred during streaming for session {session_id}: {e}")
            break
    logging.info(f"Stream completed successfully for session {session_id}.")
    logging.info(f"Completed Assistant Response for session {session_id}:\n{output_stream.response}")
    # Join outside the lock so the worker thread can still acquire stream_lock
    # to mark itself done, then clean up the session data once done.
    thread.join()
    with stream_lock:
        del user_streams[session_id]
    logging.info(f"Stream thread joined and resources cleaned up for session {session_id}.")

def process(url, session_id):
    """
    Process a YouTube URL: parse the video id, retrieve its transcript, and prepare the session for AI processing.

    Args:
        url (str): The YouTube URL provided by the user.
        session_id (str): The unique session identifier.

    Returns:
        tuple: (success (bool), message (str or None), status_code (int or None))
    """
    current_time = datetime.now(pytz.timezone('America/New_York')).strftime('%Y-%m-%d %H:%M:%S')
    logging.info(f"New Entry at {current_time} for session {session_id}")
    logging.info(f"URL: {url}")
    video_id = get_video_id(url)
    if not video_id:
        logging.warning(f"Could not parse video id from URL: {url}")
        return (False, "Couldn't parse video ID from URL. (Are you sure you entered a valid YouTube.com or YouTu.be URL?)", 400)
    logging.info(f"Parsed Video ID: {video_id}")
    transcript = get_auto_transcript(video_id)
    if not transcript:
        logging.error(f"Error: could not retrieve transcript for session {session_id}. Assistant won't be called.")
        return (False, "Successfully parsed the video ID from the URL, but the transcript is disabled by the video owner or otherwise unavailable.", 200)

    # Initialize session data for streaming (under the lock, since user_streams
    # is shared across request threads).
    with stream_lock:
        user_streams[session_id] = {
            'output_stream': None,
            'thread': None
        }
        user_streams[session_id]['output_stream'] = StreamOutput()
        thread = threading.Thread(
            name=f"create_stream_{session_id}",
            target=create_and_stream,
            args=(transcript, session_id)
        )
        user_streams[session_id]['thread'] = thread
    logging.info(f"Stream preparation complete for session {session_id}, sending reply.")
    return (True, None, None)
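
# Illustrative caller-side wiring (a hypothetical Flask route; the web layer that
# imports this module is not shown in this file, so `app`, `request`, and `Response`
# are assumptions for the sketch):
#
#   @app.route("/process", methods=["POST"])
#   def handle():
#       session_id = str(uuid.uuid4())
#       ok, message, status = process(request.form["url"], session_id)
#       if not ok:
#           return message, status
#       return Response(yoink(session_id), mimetype="text/plain")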

def get_video_id(url):
    """
    Extract the YouTube video ID from a URL.

    Args:
        url (str): The YouTube URL.

    Returns:
        str or None: The video ID if found, otherwise None.
    """
    youtu_be = r'(?<=youtu\.be/)([A-Za-z0-9_-]{11})'
    youtube_com = r'(?<=youtube\.com\/watch\?v=)([A-Za-z0-9_-]{11})'
    id_match = re.search(youtu_be, url)
    if not id_match:
        id_match = re.search(youtube_com, url)
    if not id_match:
        logging.warning(f"Failed to parse video ID from URL: {url}")
        return None
    return id_match.group(1)
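
# Examples of URLs the patterns above accept (placeholder ID shown for illustration):
#   https://youtu.be/abcdefghijk                 -> "abcdefghijk"
#   https://www.youtube.com/watch?v=abcdefghijk  -> "abcdefghijk"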

def get_auto_transcript(video_id):
    """
    Retrieve and format the transcript from a YouTube video.

    Args:
        video_id (str): The YouTube video identifier.

    Returns:
        str or None: The formatted transcript if successful; otherwise None.
    """
    trans_api_errors = youtube_transcript_api._errors
    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'], proxies=None, cookies=None, preserve_formatting=False)
    except trans_api_errors.CouldNotRetrieveTranscript as e:
        # Base class for the library's retrieval failures (TranscriptsDisabled,
        # NoTranscriptFound, etc.), so any of them results in returning None.
        logging.exception(f"Exception while fetching transcript: {e}")
        return None
    formatter = TextFormatter()
    txt_transcript = formatter.format_transcript(transcript)
    logging.info("Transcript successfully retrieved and formatted.")
    return txt_transcript

# Initialize a global output_stream just for main module logging (not used for per-session streaming).
output_stream = StreamOutput()
logging.info(f"Main initialized at {datetime.now(pytz.timezone('America/New_York')).strftime('%Y-%m-%d %H:%M:%S')}. Application starting.")