1+ #!/usr/bin/env python3
2+ """
3+ RunPod Serverless GPU Transcription Handler
4+ Clean version with minimal dependencies for faster build
5+ """
6+
7+ import runpod
8+ from faster_whisper import WhisperModel
9+ import tempfile
10+ import base64
11+ import os
12+ import logging
13+ import re
14+ from datetime import datetime
15+
# Configure logging for the worker process.
logging .basicConfig (level = logging .INFO )
# Module-level logger used by all functions in this handler.
logger = logging .getLogger (__name__ )

# Global model instance (loaded once per worker, reused across requests)
whisper_model = None
22+
def clean_hallucinated_text(text: str) -> str:
    """Strip repetitive/hallucinated phrases from Whisper output.

    Collapses whitespace runs, then scans for a phrase of 2-5 words that
    repeats three or more times back-to-back; when found, the text is
    truncated after the second occurrence of that phrase. Inputs shorter
    than three characters (after stripping) yield "".
    """
    if not text or len(text.strip()) < 3:
        return ""

    # Collapse all whitespace runs into single spaces.
    text = re.sub(r'\s+', ' ', text.strip())

    tokens = text.split()
    if len(tokens) < 4:
        return text

    # Scan for consecutive phrase repetitions, shortest phrases first.
    for size in (2, 3, 4, 5):
        if len(tokens) < size * 3:
            continue
        for start in range(len(tokens) - size * 3 + 1):
            candidate = tokens[start:start + size]

            # Count how many times the candidate repeats back-to-back.
            count = 1
            cursor = start + size
            while cursor + size <= len(tokens) and tokens[cursor:cursor + size] == candidate:
                count += 1
                cursor += size

            # Three or more repeats: keep two copies, drop the rest.
            if count >= 3:
                tokens = tokens[:start + size * 2]
                break

    return ' '.join(tokens)
59+
def load_whisper_model():
    """Initialize the Whisper model (called once per worker).

    Model size and precision are read from the WHISPER_MODEL and
    WHISPER_COMPUTE_TYPE environment variables (defaults "medium" /
    "float16"). The loaded model is cached in the module-level
    ``whisper_model`` so later calls return it without reloading.

    Returns:
        The cached faster-whisper ``WhisperModel`` instance (CUDA device).

    Raises:
        Exception: whatever ``WhisperModel`` raises on load failure,
            re-raised unchanged after logging.
    """
    global whisper_model

    if whisper_model is None:
        try:
            model_size = os.getenv("WHISPER_MODEL", "medium")
            compute_type = os.getenv("WHISPER_COMPUTE_TYPE", "float16")

            logger.info(f"Loading faster-whisper model: {model_size} with {compute_type} precision on GPU")
            # cpu_threads=1: presumably to minimize CPU contention since
            # decoding runs on CUDA — confirm against deployment tuning.
            whisper_model = WhisperModel(
                model_size,
                device="cuda",
                compute_type=compute_type,
                cpu_threads=1
            )
            logger.info("✅ Faster-whisper GPU model loaded successfully")

        except Exception as e:
            logger.error(f"❌ Failed to load whisper model: {e}")
            # Bare `raise` preserves the original traceback; `raise e`
            # would add a redundant re-raise frame.
            raise

    return whisper_model
83+
def handler(job):
    """Handle a transcription request from RunPod serverless.

    Expects ``job['input']`` with:
        audio_b64 (str): base64-encoded WAV audio (required).
        session_id (str): caller session identifier (optional).
        chunk_index (int): position of this chunk in the session (optional).

    Returns:
        dict: transcription text, cleaned segments, and processing metrics
        on success; ``{"success": False, "error": ...}`` on failure.
    """
    processing_start_time = datetime.now()
    # Defaults so the outer exception handler can always reference these,
    # even when the failure happens before the input payload is parsed
    # (previously this raised NameError and masked the real error).
    session_id = 'unknown'
    chunk_index = 0

    try:
        # Load model if not already loaded (cached per worker).
        model = load_whisper_model()

        # Extract job inputs.
        input_data = job['input']
        audio_b64 = input_data['audio_b64']
        session_id = input_data.get('session_id', 'unknown')
        chunk_index = input_data.get('chunk_index', 0)

        logger.info(f"🎙️ Processing [{session_id}:{chunk_index}] via GPU serverless")

        # Decode base64 audio data.
        try:
            audio_data = base64.b64decode(audio_b64)
        except Exception as e:
            return {"success": False, "error": f"Invalid base64 audio data: {str(e)}"}

        logger.info(f"📦 Decoded audio: {len(audio_data)} bytes")

        # Write audio to a temporary file so faster-whisper can read it.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
            temp_file.write(audio_data)
            temp_path = temp_file.name

        try:
            # Transcription with anti-hallucination settings.
            transcription_start = datetime.now()

            language = os.getenv("WHISPER_LANGUAGE", "de")
            beam_size = int(os.getenv("WHISPER_BEAM_SIZE", "3"))
            temperature = float(os.getenv("WHISPER_TEMPERATURE", "0.2"))

            segments, info = model.transcribe(
                temp_path,
                language=language,
                beam_size=beam_size,
                temperature=temperature,
                word_timestamps=True,
                vad_filter=True,
                vad_parameters=dict(
                    min_silence_duration_ms=100,
                    min_speech_duration_ms=100,
                    speech_pad_ms=100
                ),
                # Do not seed the decoder with earlier output; reduces
                # run-on hallucinations between chunks.
                condition_on_previous_text=False,
                compression_ratio_threshold=2.4
            )

            # transcribe() returns a lazy iterator; materialize it so the
            # timing below covers the actual decoding work.
            segments_list = list(segments)
            transcription_time = (datetime.now() - transcription_start).total_seconds()

            # Format response with hallucination cleaning.
            cleaned_segments = []
            full_text_parts = []

            for segment in segments_list:
                original_text = segment.text.strip()
                cleaned_text = clean_hallucinated_text(original_text)

                # Skip segments that were entirely hallucinated noise.
                if not cleaned_text:
                    continue

                cleaned_segments.append({
                    "start": float(segment.start),
                    "end": float(segment.end),
                    "text": cleaned_text,
                    "speaker": "SPEAKER_00",
                    "confidence": float(getattr(segment, 'avg_logprob', 0.0))
                })
                full_text_parts.append(cleaned_text)

            full_text = " ".join(full_text_parts)

            # Processing metrics (RTF = processing time / audio duration).
            total_processing_time = (datetime.now() - processing_start_time).total_seconds()
            audio_duration = info.duration
            rtf = total_processing_time / audio_duration if audio_duration > 0 else 0

            # Response format compatible with existing Railway app.
            result = {
                "text": full_text,
                "language": info.language,
                "language_probability": float(info.language_probability),
                "duration": float(audio_duration),
                "segments": cleaned_segments,
                "processing_info": {
                    "transcription_time": transcription_time,
                    "total_processing_time": total_processing_time,
                    "real_time_factor": rtf,
                    "model": os.getenv("WHISPER_MODEL", "medium"),
                    "compute_type": os.getenv("WHISPER_COMPUTE_TYPE", "float16"),
                    "device": "cuda",
                    "speakers_detected": len(set(seg["speaker"] for seg in cleaned_segments)),
                    "segments_count": len(cleaned_segments),
                    "serverless": True
                }
            }

            logger.info(f"✅ [{session_id}:{chunk_index}] GPU RTF: {rtf:.2f} | {len(cleaned_segments)} segments")

            return result

        finally:
            # Best-effort temp-file cleanup; never mask the real error.
            try:
                os.unlink(temp_path)
            except OSError:
                pass

    except Exception as e:
        logger.error(f"❌ Transcription failed for [{session_id}:{chunk_index}]: {e}")
        # "success": False added for consistency with the base64 error path;
        # existing keys are unchanged, so callers remain compatible.
        return {
            "success": False,
            "error": str(e),
            "session_id": session_id,
            "chunk_index": chunk_index
        }
211+
# RunPod serverless entry point: blocks and dispatches incoming jobs
# to handler() for the lifetime of the worker.
runpod .serverless .start ({"handler" : handler })