
Referenced Hugging Face model pages: google/gemma-4-E2B-it, google/gemma-4-E4B-it, google/gemma-4-26B-A4B-it, google/gemma-4-31B-it.
To launch the app, run "streamlit run src\streamlit_app.py".
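# =============================================================================
# src\streamlit_app.py: Streamlit UI (Chat / Transcription / Image / Video)
# =============================================================================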
from __future__ import annotations
import json
import os
from pathlib import Path
from tempfile import NamedTemporaryFile, mkstemp
import subprocess
import streamlit as st
from transcriber_core import DEFAULT_CHUNK_SECONDS, DEFAULT_MODEL_DIR, DEFAULT_MODELS_ROOT
from transcriber_core import ChunkTranscriptionResult, GemmaTranscriber, TranscriberSettings, VIDEO_EXTENSIONS
from transcriber_core import resolve_ffmpeg_executable
st.set_page_config(page_title="Gemma4 Playground", page_icon="M", layout="wide")
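# Cache-busting version for transcriber instances kept in st.session_state:
# bump it whenever GemmaTranscriber's interface changes so that instances
# cached by a previous code version are discarded and rebuilt on reload.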
TRANSCRIBER_RUNTIME_VERSION = 2
def list_model_directories(models_root: Path) -> list[Path]:
if not models_root.is_dir():
return [DEFAULT_MODEL_DIR]
model_dirs = [path for path in models_root.iterdir() if path.is_dir()]
if not model_dirs:
return [DEFAULT_MODEL_DIR]
return sorted(model_dirs, key=lambda path: path.name.lower())
def is_uploaded_video(uploaded_file) -> bool:
mime_type = uploaded_file.type or ""
if mime_type.startswith("video/"):
return True
return Path(uploaded_file.name).suffix.lower() in VIDEO_EXTENSIONS
def get_media_format(uploaded_file) -> str:
suffix = Path(uploaded_file.name).suffix.lower()
video_map = {
".mp4": "video/mp4",
".mov": "video/quicktime",
".mkv": "video/x-matroska",
".webm": "video/webm",
".avi": "video/x-msvideo",
".m4v": "video/mp4",
}
audio_map = {
".mp3": "audio/mpeg",
".wav": "audio/wav",
".m4a": "audio/mp4",
".flac": "audio/flac",
}
if suffix in video_map:
return video_map[suffix]
if suffix in audio_map:
return audio_map[suffix]
return uploaded_file.type or "application/octet-stream"
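# Browser-friendly preview: Streamlit's video player cannot play some upload
# containers (e.g. MKV/AVI), so the upload is transcoded once to H.264/AAC
# MP4 with +faststart and both paths are cached in session state. If ffmpeg
# fails, the original bytes are served as a fallback. The temp files are not
# deleted here; they live until the OS temp directory is cleaned.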
def render_video_preview(uploaded_file, ffmpeg_path: str) -> None:
preview_key = f"video_preview::{uploaded_file.name}"
preview_path = st.session_state.get(preview_key)
uploaded_bytes = uploaded_file.getvalue()
if preview_path is None or not Path(preview_path).exists():
fd, temp_name = mkstemp(suffix=Path(uploaded_file.name).suffix)
os.close(fd)
Path(temp_name).write_bytes(uploaded_bytes)
preview_path = temp_name
st.session_state[preview_key] = preview_path
playable_path = preview_path
preview_mp4_key = f"{preview_key}::browser_mp4"
browser_mp4_path = st.session_state.get(preview_mp4_key)
if browser_mp4_path is None or not Path(browser_mp4_path).exists():
try:
ffmpeg_exe = resolve_ffmpeg_executable(ffmpeg_path)
fd, temp_mp4 = mkstemp(suffix=".mp4")
os.close(fd)
command = [
ffmpeg_exe,
"-y",
"-i",
preview_path,
"-c:v",
"libx264",
"-preset",
"veryfast",
"-pix_fmt",
"yuv420p",
"-c:a",
"aac",
"-movflags",
"+faststart",
temp_mp4,
]
subprocess.run(command, check=True, capture_output=True)
if Path(temp_mp4).exists() and Path(temp_mp4).stat().st_size > 0:
browser_mp4_path = temp_mp4
st.session_state[preview_mp4_key] = browser_mp4_path
except Exception:
browser_mp4_path = None
if browser_mp4_path is not None and Path(browser_mp4_path).exists():
playable_path = browser_mp4_path
st.video(playable_path, format="video/mp4")
def unload_current_transcriber() -> None:
current_transcriber = st.session_state.get("transcriber_instance")
if current_transcriber is not None:
current_transcriber.unload()
st.session_state["transcriber_instance"] = None
st.session_state["model_settings_key"] = None
def build_max_memory_json(raw_text: str) -> str:
try:
parsed = json.loads(raw_text)
except json.JSONDecodeError:
return raw_text
if not isinstance(parsed, dict):
return raw_text
return json.dumps(parsed, ensure_ascii=True)
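# Model lifecycle: one GemmaTranscriber is cached in st.session_state under a
# key of (runtime version, model dir, dtype, device map, max_memory).
# sync_loaded_model() unloads the cached instance as soon as any of those
# sidebar settings change, so GPU memory is released before the next load.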
def sync_loaded_model(settings: TranscriberSettings) -> None:
desired_key = (
TRANSCRIBER_RUNTIME_VERSION,
str(settings.model_dir),
settings.dtype,
settings.device_map,
settings.max_memory,
)
current_key = st.session_state.get("model_settings_key")
if current_key is not None and current_key != desired_key:
unload_current_transcriber()
def load_transcriber(settings: TranscriberSettings) -> GemmaTranscriber:
settings_key = (
TRANSCRIBER_RUNTIME_VERSION,
str(settings.model_dir),
settings.dtype,
settings.device_map,
settings.max_memory,
)
current_settings = st.session_state.get("model_settings_key")
current_transcriber = st.session_state.get("transcriber_instance")
if current_transcriber is not None and (
not hasattr(current_transcriber, "analyze_image")
or not hasattr(current_transcriber, "_ensure_causal_model")
or not hasattr(current_transcriber, "_ensure_multimodal_model")
):
unload_current_transcriber()
current_transcriber = None
current_settings = None
if current_transcriber is not None and current_settings == settings_key:
return current_transcriber
if current_transcriber is not None:
unload_current_transcriber()
transcriber = GemmaTranscriber(settings)
st.session_state["transcriber_instance"] = transcriber
st.session_state["model_settings_key"] = settings_key
return transcriber
def render_chat_ui(
model_dir: Path,
dtype: str,
device_map: str,
max_memory: str,
ffmpeg_path: str,
) -> None:
st.subheader("Chat")
st.caption("Gemma4 と通常チャットを行います。")
if "chat_messages" not in st.session_state:
st.session_state["chat_messages"] = []
system_prompt = st.text_area(
"System prompt",
value="You are a helpful assistant.",
height=100,
)
enable_thinking = st.checkbox("Enable reasoning", value=False)
show_thinking = st.checkbox("Show reasoning", value=True, disabled=not enable_thinking)
chat_max_new_tokens = st.slider(
"Chat max new tokens",
min_value=64,
max_value=1024,
value=256,
step=64,
)
if st.button("Clear chat history", use_container_width=True):
st.session_state["chat_messages"] = []
st.rerun()
for message in st.session_state["chat_messages"]:
with st.chat_message(message["role"]):
st.markdown(message["content"])
if message["role"] == "assistant" and message.get("thinking"):
with st.expander("Reasoning", expanded=False):
st.text(message["thinking"])
user_input = st.chat_input("Gemma4 に質問する")
if not user_input:
return
st.session_state["chat_messages"].append({"role": "user", "content": user_input})
with st.chat_message("user"):
st.markdown(user_input)
settings = TranscriberSettings(
model_dir=model_dir,
dtype=dtype,
device_map=device_map,
max_memory=max_memory,
ffmpeg_path=ffmpeg_path,
)
try:
with st.spinner("Loading model..."):
model_session = load_transcriber(settings)
with st.chat_message("assistant"):
result = model_session.chat(
st.session_state["chat_messages"],
max_new_tokens=chat_max_new_tokens,
system_prompt=system_prompt,
enable_thinking=enable_thinking,
)
response = result.content or "(empty response)"
st.markdown(response)
if show_thinking and result.thinking:
with st.expander("Reasoning", expanded=False):
st.text(result.thinking)
st.session_state["chat_messages"].append(
{
"role": "assistant",
"content": response,
"thinking": result.thinking,
}
)
except RuntimeError as exc:
st.error(str(exc))
if st.session_state["chat_messages"][-1]["role"] == "user":
st.session_state["chat_messages"].pop()
def render_transcription_ui(
model_dir: Path,
language: str,
mode: str,
chunk_seconds: int,
max_new_tokens: int,
dtype: str,
device_map: str,
max_memory: str,
ffmpeg_path: str,
) -> None:
st.subheader("Transcription")
st.caption("音声または動画ファイルをアップロードして文字起こしします。")
uploaded_file = st.file_uploader(
"Audio or video file",
type=["mp3", "wav", "m4a", "flac", "mp4", "mov", "mkv", "webm", "avi", "m4v"],
accept_multiple_files=False,
)
col1, col2 = st.columns([2, 1])
with col1:
st.write("音声または動画ファイルをアップロードして `Transcribe` を押してください。")
with col2:
transcribe_clicked = st.button("Transcribe", use_container_width=True)
if uploaded_file is not None:
if is_uploaded_video(uploaded_file):
st.video(uploaded_file.getvalue(), format=get_media_format(uploaded_file))
else:
st.audio(uploaded_file.getvalue(), format=get_media_format(uploaded_file))
if not transcribe_clicked:
return
if uploaded_file is None:
st.error("音声または動画ファイルを選択してください。")
return
settings = TranscriberSettings(
model_dir=model_dir,
language=language,
mode=mode,
chunk_seconds=chunk_seconds,
max_new_tokens=max_new_tokens,
dtype=dtype,
device_map=device_map,
max_memory=max_memory,
ffmpeg_path=ffmpeg_path,
)
status = st.status("Preparing transcription...", expanded=True)
progress_placeholder = st.empty()
chunk_results_placeholder = st.empty()
final_result_header_placeholder = st.empty()
final_result_placeholder = st.empty()
try:
with NamedTemporaryFile(
delete=False,
suffix=Path(uploaded_file.name).suffix,
) as temp_file:
temp_path = Path(temp_file.name)
temp_file.write(uploaded_file.getbuffer())
with st.spinner("Loading model..."):
transcriber = load_transcriber(settings)
def update_progress(message: str) -> None:
progress_placeholder.info(message)
status.write(message)
live_chunk_results: list[ChunkTranscriptionResult] = []
def render_chunk_results(results: list[ChunkTranscriptionResult]) -> None:
with chunk_results_placeholder.container():
if not results:
return
st.subheader("Chunk Results")
for result in results:
with st.expander(
f"Chunk {result.chunk_index + 1}",
expanded=result == results[-1],
):
st.markdown("Prompt")
st.code(result.prompt_text, language="text")
st.markdown("Raw")
st.code(result.raw_text or "(empty)", language="text")
st.markdown("Cleaned")
st.code(result.cleaned_text or "(empty)", language="text")
def render_final_result(results: list[ChunkTranscriptionResult]) -> str:
transcript_text = "\n".join(
result.cleaned_text for result in results if result.cleaned_text
).strip()
final_result_header_placeholder.subheader("Final Result")
final_result_placeholder.code(
transcript_text or "(empty)",
language="text",
)
return transcript_text
def handle_chunk_result(result: ChunkTranscriptionResult) -> None:
live_chunk_results.append(result)
render_chunk_results(live_chunk_results)
render_final_result(live_chunk_results)
chunk_results = transcriber.transcribe_chunks(
temp_path,
progress_callback=update_progress,
chunk_result_callback=handle_chunk_result,
)
transcript = render_final_result(chunk_results)
status.update(label="Transcription completed", state="complete")
render_chunk_results(chunk_results)
st.download_button(
"Download transcript",
data=transcript.encode("utf-8"),
file_name=f"{Path(uploaded_file.name).stem}_transcript.txt",
mime="text/plain",
use_container_width=True,
)
except RuntimeError as exc:
status.update(label="Transcription failed", state="error")
st.error(str(exc))
finally:
if "temp_path" in locals() and temp_path.exists():
temp_path.unlink()
def render_image_ui(
model_dir: Path,
language: str,
max_new_tokens: int,
dtype: str,
device_map: str,
max_memory: str,
ffmpeg_path: str,
) -> None:
st.subheader("Image")
st.caption("画像をアップロードして内容の説明や文字認識を行います。")
uploaded_file = st.file_uploader(
"Image file",
type=["png", "jpg", "jpeg", "webp", "bmp"],
accept_multiple_files=False,
key="image_uploader",
)
image_prompt = st.text_area(
"Image prompt",
value="この画像を詳しく説明してください。文字があれば読み取ってください。",
height=100,
)
col1, col2 = st.columns([2, 1])
with col1:
st.write("画像をアップロードして `Analyze Image` を押してください。")
with col2:
analyze_clicked = st.button("Analyze Image", use_container_width=True)
if uploaded_file is not None:
st.image(uploaded_file, use_container_width=True)
if not analyze_clicked:
return
if uploaded_file is None:
st.error("画像ファイルを選択してください。")
return
settings = TranscriberSettings(
model_dir=model_dir,
language=language,
dtype=dtype,
device_map=device_map,
max_memory=max_memory,
ffmpeg_path=ffmpeg_path,
)
try:
with NamedTemporaryFile(
delete=False,
suffix=Path(uploaded_file.name).suffix,
) as temp_file:
temp_path = Path(temp_file.name)
temp_file.write(uploaded_file.getbuffer())
with st.spinner("Loading model..."):
transcriber = load_transcriber(settings)
with st.spinner("Analyzing image..."):
result = transcriber.analyze_image(
temp_path,
prompt=image_prompt,
language=language,
max_new_tokens=max_new_tokens,
)
st.subheader("Result")
st.text_area("Image Analysis", value=result, height=320)
except RuntimeError as exc:
st.error(str(exc))
finally:
if "temp_path" in locals() and temp_path.exists():
temp_path.unlink()
def render_video_ui(
model_dir: Path,
language: str,
max_new_tokens: int,
dtype: str,
device_map: str,
max_memory: str,
ffmpeg_path: str,
) -> None:
st.subheader("Video")
st.caption("動画をアップロードして内容の説明や画面内テキストの読取りを行います。")
uploaded_file = st.file_uploader(
"Video file",
type=["mp4", "mov", "mkv", "webm", "avi", "m4v"],
accept_multiple_files=False,
key="video_uploader",
)
video_prompt = st.text_area(
"Video prompt",
value="この動画の内容を時系列で分かりやすく説明してください。重要な場面と表示テキストも含めてください。",
height=100,
)
col1, col2 = st.columns([2, 1])
with col1:
st.write("動画をアップロードして `Analyze Video` を押してください。")
with col2:
analyze_clicked = st.button("Analyze Video", use_container_width=True)
if uploaded_file is not None:
render_video_preview(uploaded_file, ffmpeg_path)
if not analyze_clicked:
return
if uploaded_file is None:
st.error("動画ファイルを選択してください。")
return
settings = TranscriberSettings(
model_dir=model_dir,
language=language,
dtype=dtype,
device_map=device_map,
max_memory=max_memory,
ffmpeg_path=ffmpeg_path,
)
try:
with NamedTemporaryFile(
delete=False,
suffix=Path(uploaded_file.name).suffix,
) as temp_file:
temp_path = Path(temp_file.name)
temp_file.write(uploaded_file.getbuffer())
with st.spinner("Loading model..."):
transcriber = load_transcriber(settings)
with st.spinner("Analyzing video..."):
chunk_results = transcriber.analyze_video_chunks(
temp_path,
prompt=video_prompt,
language=language,
max_new_tokens=max_new_tokens,
)
result = transcriber.analyze_video(
temp_path,
prompt=video_prompt,
language=language,
max_new_tokens=max_new_tokens,
)
if chunk_results:
st.subheader("Segment Results")
for result_item in chunk_results:
label = (
f"Segment {result_item.chunk_index + 1} "
f"({result_item.start_seconds:.0f}s - {result_item.end_seconds:.0f}s)"
)
with st.expander(label, expanded=False):
st.markdown("Prompt")
st.code(result_item.prompt_text, language="text")
st.markdown("Raw")
st.code(result_item.raw_text or "(empty)", language="text")
st.markdown("Parsed")
st.code(result_item.cleaned_text or "(empty)", language="text")
st.subheader("Result")
st.text_area("Video Analysis", value=result, height=320)
except RuntimeError as exc:
st.error(str(exc))
finally:
if "temp_path" in locals() and temp_path.exists():
temp_path.unlink()
def main() -> None:
st.title("Gemma4 Playground")
st.caption("ローカル Gemma4 モデルでチャット、文字起こし、画像認識、動画説明を切り替えて試します。")
with st.sidebar:
st.subheader("Settings")
ui_mode = st.radio("UI mode", options=["Chat", "Transcription", "Image", "Video"], index=0)
available_models = list_model_directories(DEFAULT_MODELS_ROOT)
model_names = [path.name for path in available_models]
default_index = 0
if DEFAULT_MODEL_DIR in available_models:
default_index = available_models.index(DEFAULT_MODEL_DIR)
selected_model_name = st.selectbox(
"Model",
options=model_names,
index=default_index,
)
model_dir = next(
path for path in available_models if path.name == selected_model_name
)
st.caption(f"Model path: {model_dir}")
dtype = st.selectbox(
"Torch dtype",
options=["auto", "bfloat16", "float16", "float32"],
index=0,
)
device_map = st.text_input("Device map", value="auto")
max_memory = st.text_area(
"Max memory JSON",
value="",
height=100,
help="1 GPU で使うなら空のままがおすすめです。複数 GPU で分散したい時だけ JSON を入れてください。",
)
ffmpeg_path = st.text_input(
"ffmpeg path",
value="",
help="Optional full path to ffmpeg.exe for video transcription.",
)
effective_max_memory = build_max_memory_json(max_memory)
if effective_max_memory != max_memory and effective_max_memory.strip():
st.caption(f"Effective max_memory: {effective_max_memory}")
elif not effective_max_memory.strip():
st.caption("Effective max_memory: auto")
if ui_mode == "Transcription":
language = st.selectbox("Language", options=["ja", "en"], index=0)
mode = st.selectbox("Mode", options=["speech", "song"], index=0)
chunk_seconds = st.slider(
"Chunk seconds",
min_value=10,
max_value=30,
value=DEFAULT_CHUNK_SECONDS,
step=5,
)
max_new_tokens = st.slider(
"Transcription max new tokens",
min_value=64,
max_value=1024,
value=256,
step=64,
)
elif ui_mode == "Image":
language = st.selectbox("Language", options=["ja", "en"], index=0)
mode = "speech"
chunk_seconds = DEFAULT_CHUNK_SECONDS
max_new_tokens = st.slider(
"Image max new tokens",
min_value=64,
max_value=1024,
value=256,
step=64,
)
elif ui_mode == "Video":
language = st.selectbox("Language", options=["ja", "en"], index=0)
mode = "speech"
chunk_seconds = DEFAULT_CHUNK_SECONDS
max_new_tokens = st.slider(
"Video max new tokens",
min_value=64,
max_value=1024,
value=256,
step=64,
)
else:
language = "ja"
mode = "speech"
chunk_seconds = DEFAULT_CHUNK_SECONDS
max_new_tokens = 256
sync_loaded_model(
TranscriberSettings(
model_dir=model_dir,
dtype=dtype,
device_map=device_map,
max_memory=effective_max_memory,
ffmpeg_path=ffmpeg_path,
)
)
if ui_mode == "Chat":
render_chat_ui(
model_dir=model_dir,
dtype=dtype,
device_map=device_map,
max_memory=effective_max_memory,
ffmpeg_path=ffmpeg_path,
)
elif ui_mode == "Image":
render_image_ui(
model_dir=model_dir,
language=language,
max_new_tokens=max_new_tokens,
dtype=dtype,
device_map=device_map,
max_memory=effective_max_memory,
ffmpeg_path=ffmpeg_path,
)
elif ui_mode == "Video":
render_video_ui(
model_dir=model_dir,
language=language,
max_new_tokens=max_new_tokens,
dtype=dtype,
device_map=device_map,
max_memory=effective_max_memory,
ffmpeg_path=ffmpeg_path,
)
else:
render_transcription_ui(
model_dir=model_dir,
language=language,
mode=mode,
chunk_seconds=chunk_seconds,
max_new_tokens=max_new_tokens,
dtype=dtype,
device_map=device_map,
max_memory=effective_max_memory,
ffmpeg_path=ffmpeg_path,
)
if __name__ == "__main__":
main()
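# =============================================================================
# transcriber_core.py: core library imported by streamlit_app.py above.
#
# Minimal usage sketch outside Streamlit (assumes a local instruction-tuned
# export under DEFAULT_MODEL_DIR; "interview.mp4" is a placeholder input):
#
#   from pathlib import Path
#   settings = TranscriberSettings(
#       model_dir=Path(r"H:\LLM_Models\safetensor\gemma4\E2B-it"),
#       language="ja",
#   )
#   transcriber = GemmaTranscriber(settings)
#   for chunk in transcriber.transcribe_chunks(Path("interview.mp4")):
#       print(chunk.chunk_index, chunk.cleaned_text)
# =============================================================================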
from __future__ import annotations
from dataclasses import dataclass
import gc
import json
import os
from pathlib import Path
import re
import shutil
import subprocess
from tempfile import TemporaryDirectory
from typing import Callable, Iterable
import wave
DEFAULT_MODEL_DIR = Path(r"H:\LLM_Models\safetensor\gemma4\E2B-it")
DEFAULT_MODELS_ROOT = DEFAULT_MODEL_DIR.parent
DEFAULT_SAMPLE_RATE = 16_000
DEFAULT_CHUNK_SECONDS = 30
DEFAULT_VIDEO_CHUNK_SECONDS = 60
DEFAULT_TEMPERATURE = 1.0
DEFAULT_TOP_P = 0.95
DEFAULT_TOP_K = 64
AUDIO_EXTENSIONS = {".mp3", ".wav", ".m4a", ".flac", ".aac", ".ogg", ".opus"}
VIDEO_EXTENSIONS = {".mp4", ".mov", ".mkv", ".webm", ".avi", ".m4v"}
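# Opt in to PyTorch's expandable-segments CUDA allocator (only if the user
# has not already set PYTORCH_CUDA_ALLOC_CONF) to reduce memory fragmentation
# when models are repeatedly loaded and unloaded in one process.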
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
@dataclass(frozen=True)
class TranscriberSettings:
model_dir: Path = DEFAULT_MODEL_DIR
language: str = "ja"
mode: str = "speech"
chunk_seconds: int = DEFAULT_CHUNK_SECONDS
max_new_tokens: int = 256
dtype: str = "auto"
device_map: str = "auto"
max_memory: str = ""
ffmpeg_path: str = ""
sample_rate: int = DEFAULT_SAMPLE_RATE
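# Heavy dependencies (librosa, numpy, torch, transformers) are imported
# lazily so importing this module stays cheap, and a missing package is
# reported as a readable RuntimeError pointing at setup.ps1 instead of a
# bare ModuleNotFoundError.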
def lazy_imports():
try:
import librosa
import numpy as np
import torch
from transformers import (
AutoModelForCausalLM,
AutoModelForMultimodalLM,
AutoProcessor,
)
except ModuleNotFoundError as exc:
missing = exc.name or "required package"
raise RuntimeError(
"Missing dependency: "
f"{missing}\n"
"Run `powershell -ExecutionPolicy Bypass -File "
"D:\\Workspace\\Python\\LLM_GEMMA4_TEST\\setup.ps1` first."
) from exc
return (
librosa,
np,
torch,
AutoModelForCausalLM,
AutoModelForMultimodalLM,
AutoProcessor,
)
def resolve_torch_dtype(torch_module, dtype_name: str):
if dtype_name == "auto":
return "auto"
return getattr(torch_module, dtype_name)
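# Parses the "Max memory JSON" setting into the max_memory mapping accepted
# by transformers' from_pretrained(): CUDA devices become integer keys and
# "cpu"/"mps"/"disk" stay strings. Worked example (by hand, not app output):
#   parse_max_memory('{"cuda:0": "16GiB", "cpu": "32GiB"}')
#   -> {0: "16GiB", "cpu": "32GiB"}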
def parse_max_memory(max_memory_text: str) -> dict[str, str] | None:
if not max_memory_text.strip():
return None
try:
parsed = json.loads(max_memory_text)
except json.JSONDecodeError as exc:
raise RuntimeError(
"Invalid max_memory JSON. Example: "
'{"cuda:0":"16GiB","cuda:1":"10GiB","cpu":"32GiB"}'
) from exc
if not isinstance(parsed, dict) or not parsed:
raise RuntimeError(
"max_memory must be a JSON object like "
'{"cuda:0":"16GiB","cuda:1":"10GiB","cpu":"32GiB"}'
)
normalized: dict[int | str, str] = {}
for key, value in parsed.items():
if not isinstance(key, str) or not isinstance(value, str):
raise RuntimeError("max_memory keys and values must be strings.")
lowered = key.lower()
if lowered.startswith("cuda:"):
device_index = lowered.split(":", 1)[1]
if not device_index.isdigit():
raise RuntimeError(
"CUDA device keys must look like cuda:0 or cuda:1."
)
normalized[int(device_index)] = value
elif lowered.isdigit():
normalized[int(lowered)] = value
elif lowered in {"cpu", "mps", "disk"}:
normalized[lowered] = value
else:
raise RuntimeError(
"Unsupported max_memory device key. Use 0, 1, cuda:0, cuda:1, cpu, mps, or disk."
)
return normalized
def load_audio(librosa_module, audio_path: Path, sample_rate: int):
try:
waveform, actual_sr = librosa_module.load(
audio_path.as_posix(), sr=sample_rate, mono=True
)
except Exception as exc:
raise RuntimeError(
"Failed to load audio. Install ffmpeg if MP3 decoding is unavailable.\n"
"You can also convert the file to WAV and retry.\n"
f"Details: {exc}"
) from exc
if waveform.size == 0:
raise RuntimeError("The audio file is empty.")
return waveform, actual_sr
def is_video_file(path: Path) -> bool:
return path.suffix.lower() in VIDEO_EXTENSIONS
def extract_audio_from_video(
media_path: Path,
output_path: Path,
sample_rate: int,
ffmpeg_path: str = "",
) -> None:
resolved_ffmpeg = resolve_ffmpeg_executable(ffmpeg_path)
command = [
resolved_ffmpeg,
"-y",
"-i",
str(media_path),
"-vn",
"-ac",
"1",
"-ar",
str(sample_rate),
str(output_path),
]
try:
completed = subprocess.run(
command,
check=True,
capture_output=True,
)
except FileNotFoundError as exc:
raise RuntimeError(
"ffmpeg was not found. Set an ffmpeg.exe path in the UI/CLI, or install it in PATH."
) from exc
except subprocess.CalledProcessError as exc:
details_bytes = exc.stderr or exc.stdout or b""
details = details_bytes.decode("utf-8", errors="replace").strip()
raise RuntimeError(
"Failed to extract audio from the video file with ffmpeg.\n"
f"Details: {details}"
) from exc
if not output_path.exists() or output_path.stat().st_size == 0:
details_bytes = completed.stderr or completed.stdout or b""
details = details_bytes.decode("utf-8", errors="replace").strip()
raise RuntimeError(
"ffmpeg finished but no audio file was created.\n"
f"Details: {details}"
)
def resolve_ffmpeg_executable(ffmpeg_path: str = "") -> str:
ffmpeg_candidates: list[str] = []
if ffmpeg_path.strip():
ffmpeg_candidates.append(ffmpeg_path.strip())
ffmpeg_candidates.extend(
[
"ffmpeg",
r"C:\ffmpeg\bin\ffmpeg.exe",
r"C:\Program Files\ffmpeg\bin\ffmpeg.exe",
r"C:\ProgramData\chocolatey\bin\ffmpeg.exe",
str(Path.home() / "scoop" / "apps" / "ffmpeg" / "current" / "bin" / "ffmpeg.exe"),
]
)
for candidate in ffmpeg_candidates:
if candidate == "ffmpeg":
located = shutil.which("ffmpeg")
if located:
return located
continue
expanded = Path(candidate).expanduser()
if expanded.exists():
return str(expanded)
raise RuntimeError(
"ffmpeg was not found. Set an ffmpeg.exe path in the UI/CLI, or install it in PATH."
)
def resolve_ffprobe_executable(ffmpeg_path: str = "") -> str:
candidates: list[str] = []
if ffmpeg_path.strip():
ffmpeg_file = Path(ffmpeg_path.strip()).expanduser()
if ffmpeg_file.name.lower() == "ffmpeg.exe":
candidates.append(str(ffmpeg_file.with_name("ffprobe.exe")))
candidates.append(str(ffmpeg_file))
candidates.extend(
[
"ffprobe",
r"C:\ffmpeg\bin\ffprobe.exe",
r"C:\Program Files\ffmpeg\bin\ffprobe.exe",
r"C:\ProgramData\chocolatey\bin\ffprobe.exe",
str(Path.home() / "scoop" / "apps" / "ffmpeg" / "current" / "bin" / "ffprobe.exe"),
]
)
for candidate in candidates:
if candidate == "ffprobe":
located = shutil.which("ffprobe")
if located:
return located
continue
expanded = Path(candidate).expanduser()
if expanded.exists():
return str(expanded)
raise RuntimeError(
"ffprobe was not found. Set an ffmpeg.exe path in the UI/CLI, or install ffmpeg/ffprobe in PATH."
)
def get_video_duration_seconds(media_path: Path, ffmpeg_path: str = "") -> float | None:
ffprobe_path = resolve_ffprobe_executable(ffmpeg_path)
command = [
ffprobe_path,
"-v",
"error",
"-show_entries",
"format=duration",
"-of",
"json",
str(media_path),
]
try:
completed = subprocess.run(
command,
check=True,
capture_output=True,
)
payload = json.loads(completed.stdout.decode("utf-8", errors="replace"))
duration = payload.get("format", {}).get("duration")
if duration is None:
return None
return float(duration)
except Exception:
return None
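# Samples at most max_frames frames spread evenly over the clip:
# fps = max_frames / duration, so an 80 s clip with max_frames=8 is sampled
# at 0.1 fps (one frame every 10 s). With no known duration it falls back to
# one frame every 8 s, still capped by -frames:v.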
def extract_frames_from_video(
media_path: Path,
output_dir: Path,
ffmpeg_path: str = "",
max_frames: int = 8,
start_time: float | None = None,
duration_seconds: float | None = None,
) -> list[Path]:
resolved_ffmpeg = resolve_ffmpeg_executable(ffmpeg_path)
output_dir.mkdir(parents=True, exist_ok=True)
frame_pattern = output_dir / "frame_%03d.jpg"
effective_duration = duration_seconds
if effective_duration is None:
effective_duration = get_video_duration_seconds(media_path, ffmpeg_path)
if effective_duration is not None and effective_duration > 0:
fps_value = max_frames / effective_duration
else:
fps_value = 1 / 8
command = [resolved_ffmpeg, "-y"]
if start_time is not None and start_time > 0:
command.extend(["-ss", f"{start_time:.3f}"])
command.extend(["-i", str(media_path)])
if duration_seconds is not None and duration_seconds > 0:
command.extend(["-t", f"{duration_seconds:.3f}"])
command.extend(
[
"-vf",
f"fps={fps_value:.6f},scale=1024:-1",
"-frames:v",
str(max_frames),
str(frame_pattern),
]
)
try:
completed = subprocess.run(
command,
check=True,
capture_output=True,
)
except subprocess.CalledProcessError as exc:
details_bytes = exc.stderr or exc.stdout or b""
details = details_bytes.decode("utf-8", errors="replace").strip()
raise RuntimeError(
"Failed to extract frames from the video file with ffmpeg.\n"
f"Details: {details}"
) from exc
frames = sorted(output_dir.glob("frame_*.jpg"))
if not frames:
details_bytes = completed.stderr or completed.stdout or b""
details = details_bytes.decode("utf-8", errors="replace").strip()
raise RuntimeError(
"ffmpeg finished but no video frames were created.\n"
f"Details: {details}"
)
return frames
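# Cuts one [start, start + duration) window out of the source video. The
# segment is re-encoded (libx264/aac) instead of stream-copied so the cut is
# frame-accurate even when start_time does not land on a keyframe.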
def split_video_segment(
media_path: Path,
output_path: Path,
start_time: float,
duration_seconds: float,
ffmpeg_path: str = "",
) -> None:
resolved_ffmpeg = resolve_ffmpeg_executable(ffmpeg_path)
command = [
resolved_ffmpeg,
"-y",
"-ss",
f"{start_time:.3f}",
"-i",
str(media_path),
"-t",
f"{duration_seconds:.3f}",
"-c:v",
"libx264",
"-preset",
"veryfast",
"-crf",
"23",
"-c:a",
"aac",
"-b:a",
"128k",
str(output_path),
]
try:
completed = subprocess.run(
command,
check=True,
capture_output=True,
)
except subprocess.CalledProcessError as exc:
details_bytes = exc.stderr or exc.stdout or b""
details = details_bytes.decode("utf-8", errors="replace").strip()
raise RuntimeError(
"Failed to split the video file with ffmpeg.\n"
f"Details: {details}"
) from exc
if not output_path.exists() or output_path.stat().st_size == 0:
details_bytes = completed.stderr or completed.stdout or b""
details = details_bytes.decode("utf-8", errors="replace").strip()
raise RuntimeError(
"ffmpeg finished but no video segment was created.\n"
f"Details: {details}"
)
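# Yields (index, window) pairs over the waveform in fixed-size windows.
# Callers pass chunk_size = sample_rate * chunk_seconds; at the 16 kHz
# default and the 30 s ceiling enforced in GemmaTranscriber that is at most
# 480_000 samples per chunk.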
def iter_chunks(waveform, chunk_size: int) -> Iterable[tuple[int, object]]:
chunk_index = 0
for start in range(0, len(waveform), chunk_size):
yield chunk_index, waveform[start : start + chunk_size]
chunk_index += 1
def build_transcription_prompt(language: str, mode: str) -> str:
if mode == "song":
return (
"Transcribe the following singing in its original language. "
"Follow these specific instructions for formatting the answer:\n"
"* Only output the lyrics.\n"
"* Do not translate.\n"
"* Do not add any other text.\n"
"* Do not add timestamps, explanations, or speaker labels."
)
return (
"Transcribe the following speech segment in its original language. "
"Follow these specific instructions for formatting the answer:\n"
"* Only output the transcription, with no newlines.\n"
"* Do not add any other text.\n"
"* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."
)
def describe_missing_chat_template_files(model_dir: Path) -> str:
expected_files = [
"chat_template.jinja",
"preprocessor_config.json",
"special_tokens_map.json",
"tokenizer.model",
]
missing = [name for name in expected_files if not (model_dir / name).exists()]
if not missing:
return "No obvious missing template files were detected."
return "Missing files: " + ", ".join(missing)
def build_transcription_prompt_variants(language: str, mode: str) -> list[str]:
language_key = language.lower()
if mode == "song":
if language_key == "ja":
return [
build_transcription_prompt(language, mode),
"歌詞だけ。",
"歌詞のみ。",
]
return [
build_transcription_prompt(language, mode),
"Output only the lyrics.",
]
if language_key == "ja":
return [
build_transcription_prompt(language, mode),
"文字起こしだけ。",
"話している内容だけを書いてください。",
]
return [
build_transcription_prompt(language, mode),
"Output only the transcript.",
]
def build_audio_chat_messages(prompt_text: str) -> list[dict[str, object]]:
return [
{
"role": "user",
"content": [
{"type": "audio", "audio": "__audio__"},
{"type": "text", "text": prompt_text.strip()},
],
}
]
def build_audio_transcription_messages(
audio_source: str,
prompt_text: str,
) -> list[dict[str, object]]:
return [
{
"role": "user",
"content": [
{"type": "audio", "audio": audio_source},
{"type": "text", "text": prompt_text.strip()},
],
}
]
def build_chat_prompt(system_prompt: str, messages: list[dict[str, str]]) -> str:
lines: list[str] = []
if system_prompt.strip():
lines.append(f"System:<|think|> {system_prompt.strip()}")
for message in messages:
role = message["role"]
content = message["content"].strip()
if not content:
continue
if role == "user":
lines.append(f"User: {content}")
else:
lines.append(f"Assistant: {content}")
lines.append("Assistant:")
return "\n".join(lines)
def build_chat_messages(
system_prompt: str,
messages: list[dict[str, str]],
) -> list[dict[str, object]]:
chat_messages: list[dict[str, object]] = []
if system_prompt.strip():
chat_messages.append(
{
"role": "system",
"content": [{"type": "text", "text": system_prompt.strip()}],
}
)
for message in messages:
role = message.get("role", "").strip()
content = message.get("content", "").strip()
if role not in {"user", "assistant"} or not content:
continue
chat_messages.append(
{
"role": role,
"content": [{"type": "text", "text": content}],
}
)
return chat_messages
def build_text_chat_template_messages(
system_prompt: str,
messages: list[dict[str, str]],
) -> list[dict[str, str]]:
chat_messages: list[dict[str, str]] = []
if system_prompt.strip():
chat_messages.append({"role": "system", "content": system_prompt.strip()})
for message in messages:
role = message.get("role", "").strip()
content = message.get("content", "").strip()
if role not in {"user", "assistant"} or not content:
continue
chat_messages.append({"role": role, "content": content})
return chat_messages
def build_image_prompt(language: str, user_prompt: str) -> str:
prompt = user_prompt.strip()
if prompt:
return prompt
if language.lower() == "ja":
return "この画像を詳しく説明してください。文字があれば読み取ってください。"
return "Describe this image in detail and read any visible text."
def build_video_prompt(language: str, user_prompt: str) -> str:
prompt = user_prompt.strip()
if prompt:
return prompt
return "Describe this video."
def build_video_chunk_prompt(
language: str,
user_prompt: str,
start_seconds: float,
end_seconds: float,
) -> str:
base_prompt = build_video_prompt(language, user_prompt)
return (
f"This is the video segment from {start_seconds:.0f}s to {end_seconds:.0f}s. "
f"{base_prompt}"
)
def build_video_summary_prompt(language: str, segment_texts: list[str]) -> str:
joined_segments = "\n\n".join(
f"Segment {index + 1}:\n{text}"
for index, text in enumerate(segment_texts)
if text.strip()
)
if language.lower() == "ja":
return (
"以下は動画の各区間の説明です。全体として自然につながるように、"
"冒頭から終盤まで統合した最終説明を日本語で作成してください。本文だけを書いてください。\n\n"
f"{joined_segments}"
)
return (
"The following are descriptions of each video segment. "
"Write one coherent final description of the full video from beginning to end. Output only the description.\n\n"
f"{joined_segments}"
)
def clean_chat_text(text: str) -> str:
separators = (
"\nUser:",
"\nAssistant:",
"User:",
"Assistant:",
"<eos>",
)
cleaned = text.strip()
for separator in separators:
if separator in cleaned:
cleaned = cleaned.split(separator, 1)[0].strip()
return cleaned
def normalize_parsed_response(parsed) -> str:
if parsed is None:
return ""
if isinstance(parsed, str):
return parsed.strip()
if isinstance(parsed, list):
parts: list[str] = []
for part in parsed:
normalized = normalize_parsed_response(part)
if normalized:
parts.append(normalized)
return "\n".join(parts).strip()
if isinstance(parsed, dict):
role = parsed.get("role")
content = parsed.get("content")
if isinstance(content, str):
stripped = content.strip()
if stripped:
return stripped
if isinstance(content, list):
normalized_content = normalize_parsed_response(content)
if normalized_content:
return normalized_content
for key in ("response", "text", "content", "output", "answer", "final"):
value = parsed.get(key)
normalized_value = normalize_parsed_response(value)
if normalized_value:
return normalized_value
if role == "assistant":
return ""
return str(parsed).strip()
return str(parsed).strip()
def parse_chat_result(parsed, raw_text: str) -> ChatResult:
raw_stripped = raw_text.strip()
if isinstance(parsed, dict):
content = normalize_parsed_response(parsed.get("content"))
thinking = normalize_parsed_response(parsed.get("thinking"))
if thinking and raw_stripped:
content_without_thinking = raw_stripped
if "<|channel|>thought" in content_without_thinking and "<channel|>" in content_without_thinking:
content_without_thinking = content_without_thinking.split("<channel|>", 1)[1]
content_without_thinking = re.sub(r"<turn\|>\s*$", "", content_without_thinking).strip()
content_without_thinking = strip_special_only_response(content_without_thinking)
if content_without_thinking:
content = content or content_without_thinking
if content or thinking:
return ChatResult(content=content, thinking=thinking, raw_text=raw_stripped)
normalized = normalize_parsed_response(parsed)
return ChatResult(content=normalized, thinking="", raw_text=raw_stripped)
def strip_special_only_response(text: str) -> str:
stripped = text.strip()
if not stripped:
return ""
normalized = re.sub(r"<[^>\n]+>", "", stripped).strip()
if not normalized:
return ""
return stripped
def is_empty_or_special_response(text: str) -> bool:
return not strip_special_only_response(text)
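# Transcript post-processing, applied per chunk in this order (see
# GemmaTranscriber.transcribe_chunks):
#   clean_transcript_text       drop timestamps, prompt echoes, and known
#                               instruction-like boilerplate lines
#   salvage_transcript_text     looser fallback when cleaning removed all text
#   normalize_transcript_output language-aware whitespace normalization
#   dedupe_repeated_phrases     collapse degenerate repetition loops
#   is_low_quality_transcript   final gate; junk chunks become ""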
def clean_transcript_text(text: str, prompt: str) -> str:
timestamp_line_pattern = re.compile(
r"^\d{2}:\d{2}:\d{2},\d{3}(?:\s*-->\s*\d{2}:\d{2}:\d{2},\d{3})?$"
)
dot_only_pattern = re.compile(r"^[.\s]{5,}$")
banned_prefixes = (
"* ",
"- ",
"Please respond only",
"No comments.",
"here are the lyrics",
"Here are the lyrics",
"If you want to sing along",
"This video and its lyrics are not for commercial use.",
"Only output",
"Do not add",
"Output only",
"No explanation",
"Transcribe this",
"Transcribe the following",
"Follow these specific instructions",
"この歌声を",
"この音声を",
"この歌の歌詞を書き起こしてください",
"元の言語のまま、歌詞本文だけを出力してください",
"文字起こし結果のみを書いてください",
"話している内容だけを書いてください",
"歌詞だけ",
"歌詞のみ",
"歌詞を日本語にしてください",
"字幕形式にしないでください",
"タイムスタンプ",
"歌詞の本文だけを出力してください",
"本文だけを出力してください",
)
prompt_lines = {
line.strip()
for line in prompt.splitlines()
if line.strip() and line.strip() != "<|audio|>"
}
cleaned_lines: list[str] = []
for raw_line in text.splitlines():
line = raw_line.strip()
if not line:
continue
if set(line) == {"-"}:
continue
if timestamp_line_pattern.match(line):
continue
if dot_only_pattern.match(line):
continue
if line in prompt_lines:
continue
if any(line.startswith(prefix) for prefix in banned_prefixes):
continue
cleaned_lines.append(line)
return "\n".join(cleaned_lines).strip()
def salvage_transcript_text(text: str, prompt: str) -> str:
cleaned = text.replace(prompt, "").replace("<|audio|>", "").strip()
lines: list[str] = []
for raw_line in cleaned.splitlines():
line = raw_line.strip()
if not line:
continue
if line.startswith("Please respond only"):
continue
if line.startswith("No comments."):
continue
if line.lower().startswith("here are the lyrics"):
continue
if line.startswith("この歌の歌詞を書き起こしてください"):
continue
if line.startswith("元の言語のまま、歌詞本文だけを出力してください"):
continue
if line.startswith("文字起こし結果のみを書いてください"):
continue
if line.startswith("If you want to sing along"):
continue
if line.startswith("This video and its lyrics are not for commercial use."):
continue
if "Do not add any other text" in line:
continue
if set(line) == {"-"}:
continue
lines.append(line)
return "\n".join(lines).strip()
def normalize_transcript_output(text: str, language: str, mode: str) -> str:
normalized = text.strip()
if not normalized:
return normalized
if language.lower() == "ja":
if mode == "song":
lines = [line.replace(" ", "") for line in normalized.splitlines()]
return "\n".join(line for line in lines if line).strip()
return normalized.replace(" ", "")
return re.sub(r"[ \t]+", " ", normalized)
def is_low_quality_transcript(text: str) -> bool:
stripped = text.strip()
if not stripped:
return True
if re.fullmatch(r"[\d年月日時分秒、。,\-/: ]+", stripped):
return True
if len(stripped) <= 20:
repeated = re.fullmatch(r"(.{1,10})\1{1,}", stripped)
if repeated:
return True
if "\n" not in stripped:
parts = [part for part in re.split(r"(。|、|\s+)", stripped) if part and not part.isspace()]
text_only_parts = [part for part in parts if part not in {"。", "、"}]
if len(text_only_parts) >= 2 and len(set(text_only_parts)) == 1:
return True
return False
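# Collapses two common greedy-decoding failure modes: a string that is one
# phrase repeated back to back, and consecutive duplicate lines. Worked by
# hand: dedupe_repeated_phrases("ありがとうありがとう") returns "ありがとう",
# and "A\nA\nB" becomes "A\nB".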
def dedupe_repeated_phrases(text: str) -> str:
stripped = text.strip()
if not stripped:
return stripped
for unit_size in range(1, min(20, len(stripped) // 2 + 1)):
unit = stripped[:unit_size]
if unit * (len(stripped) // unit_size) == stripped[: unit_size * (len(stripped) // unit_size)]:
remainder = stripped[unit_size * (len(stripped) // unit_size) :]
if not remainder:
return unit
lines: list[str] = []
prev = None
for line in stripped.splitlines():
if line == prev:
continue
lines.append(line)
prev = line
return "\n".join(lines).strip()
@dataclass(frozen=True)
class ChunkTranscriptionResult:
chunk_index: int
prompt_text: str
raw_text: str
cleaned_text: str
@dataclass(frozen=True)
class VideoAnalysisChunkResult:
chunk_index: int
start_seconds: float
end_seconds: float
prompt_text: str
raw_text: str
cleaned_text: str
@dataclass(frozen=True)
class ChatResult:
content: str
thinking: str
raw_text: str
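# GemmaTranscriber owns one processor and at most one loaded model at a time,
# swapping between a text-only ("causal") load for plain chat and a
# multimodal load for audio/image/video, and releasing GPU memory between
# swaps (see _load_model).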
class GemmaTranscriber:
def __init__(self, settings: TranscriberSettings):
self.settings = settings
(
librosa,
np,
torch,
AutoModelForCausalLM,
AutoModelForMultimodalLM,
AutoProcessor,
) = lazy_imports()
self.librosa = librosa
self.np = np
self.torch = torch
self.AutoModelForCausalLM = AutoModelForCausalLM
self.AutoModelForMultimodalLM = AutoModelForMultimodalLM
if not settings.model_dir.is_dir():
raise RuntimeError(f"Model directory not found: {settings.model_dir}")
if settings.mode not in {"speech", "song"}:
raise RuntimeError(f"Unsupported mode: {settings.mode}")
if settings.chunk_seconds > 30:
raise RuntimeError(
"Gemma4 audio input supports up to 30 seconds per chunk. "
f"Current chunk_seconds: {settings.chunk_seconds}"
)
dtype = resolve_torch_dtype(torch, settings.dtype)
max_memory = parse_max_memory(settings.max_memory)
self.model_dtype = dtype
self.model_max_memory = max_memory
self.processor = AutoProcessor.from_pretrained(settings.model_dir)
self.model = None
self.model_kind = ""
self.model_device = None
self._load_model("multimodal")
def _build_model_kwargs(self):
model_kwargs = {
"pretrained_model_name_or_path": self.settings.model_dir,
"dtype": self.model_dtype,
"device_map": self.settings.device_map,
}
if self.model_max_memory is not None:
model_kwargs["max_memory"] = self.model_max_memory
return model_kwargs
def _load_model(self, model_kind: str) -> None:
if self.model is not None and self.model_kind == model_kind:
return
if self.model is not None:
try:
self.model.cpu()
except Exception:
pass
del self.model
self.model = None
gc.collect()
if self.torch.cuda.is_available():
self.torch.cuda.empty_cache()
try:
self.torch.cuda.ipc_collect()
except Exception:
pass
model_kwargs = self._build_model_kwargs()
if model_kind == "causal":
self.model = self.AutoModelForCausalLM.from_pretrained(**model_kwargs)
else:
self.model = self.AutoModelForMultimodalLM.from_pretrained(**model_kwargs)
self.model_kind = model_kind
self.model_device = self._resolve_model_device()
def _ensure_multimodal_model(self) -> None:
self._load_model("multimodal")
def _ensure_causal_model(self) -> None:
self._load_model("causal")
def _resolve_model_device(self):
hf_device_map = getattr(self.model, "hf_device_map", None)
if isinstance(hf_device_map, dict):
cuda_indexes: list[int] = []
for device in hf_device_map.values():
if isinstance(device, int):
cuda_indexes.append(device)
continue
if isinstance(device, str) and device.startswith("cuda:"):
suffix = device.split(":", 1)[1]
if suffix.isdigit():
cuda_indexes.append(int(suffix))
if cuda_indexes:
return self.torch.device(f"cuda:{min(cuda_indexes)}")
model_device = getattr(self.model, "device", None)
if model_device is not None:
return model_device
return next(self.model.parameters()).device
def _move_inputs(self, inputs):
moved = {}
for key, value in inputs.items():
if hasattr(value, "to"):
target_dtype = None
if hasattr(value, "dtype") and getattr(value.dtype, "is_floating_point", False):
target_dtype = self.model.dtype
if target_dtype is not None:
moved[key] = value.to(self.model_device, dtype=target_dtype)
else:
moved[key] = value.to(self.model_device)
else:
moved[key] = value
return moved
def _cleanup_generation_tensors(self, *objects) -> None:
for obj in objects:
try:
del obj
except Exception:
pass
gc.collect()
if self.torch.cuda.is_available():
self.torch.cuda.empty_cache()
def _render_audio_prompt(self, prompt_text: str) -> str:
tokenizer = getattr(self.processor, "tokenizer", None)
if tokenizer is not None and hasattr(tokenizer, "apply_chat_template"):
try:
rendered = tokenizer.apply_chat_template(
build_audio_chat_messages(prompt_text),
tokenize=False,
add_generation_prompt=True,
)
if isinstance(rendered, str) and rendered.strip():
return rendered
except Exception:
pass
return f"<|audio|>\n{prompt_text.strip()}"
def _get_generation_stop_ids(self) -> int | list[int] | None:
tokenizer = getattr(self.processor, "tokenizer", None)
if tokenizer is None:
return None
stop_ids: list[int] = []
for token_name in ("eos_token_id",):
token_id = getattr(tokenizer, token_name, None)
if isinstance(token_id, int) and token_id >= 0:
stop_ids.append(token_id)
if hasattr(tokenizer, "convert_tokens_to_ids"):
for token in ("<turn|>", "<eos>"):
try:
token_id = tokenizer.convert_tokens_to_ids(token)
except Exception:
token_id = None
if isinstance(token_id, int) and token_id >= 0:
stop_ids.append(token_id)
deduped: list[int] = []
for token_id in stop_ids:
if token_id not in deduped:
deduped.append(token_id)
if not deduped:
return None
if len(deduped) == 1:
return deduped[0]
return deduped
def _get_pad_token_id(self) -> int | None:
tokenizer = getattr(self.processor, "tokenizer", None)
if tokenizer is None:
return None
pad_token_id = getattr(tokenizer, "pad_token_id", None)
if isinstance(pad_token_id, int) and pad_token_id >= 0:
return pad_token_id
eos_token_id = getattr(tokenizer, "eos_token_id", None)
if isinstance(eos_token_id, int) and eos_token_id >= 0:
return eos_token_id
return None
def _build_audio_inputs(self, prompt_text: str, chunk, sample_rate: int):
rendered_prompt = self._render_audio_prompt(prompt_text)
attempts = (
{"audio": chunk, "sampling_rate": sample_rate},
{"audio": [chunk], "sampling_rate": sample_rate},
)
last_error = None
for extra_kwargs in attempts:
try:
return (
self.processor(
text=rendered_prompt,
return_tensors="pt",
padding=True,
**extra_kwargs,
),
rendered_prompt,
)
except Exception as exc:
last_error = exc
if last_error is not None:
raise last_error
raise RuntimeError("Failed to build audio inputs.")
def _run_audio_transcription_attempt(
self,
*,
prompt_text: str,
chunk,
sample_rate: int,
max_new_tokens: int,
do_sample: bool,
) -> tuple[str, str]:
input_ids, rendered_prompt = self._build_audio_inputs(
prompt_text, chunk, sample_rate
)
input_ids = self._move_inputs(input_ids)
outputs = None
response = ""
fallback = ""
try:
generation_kwargs = {
"max_new_tokens": max_new_tokens,
"do_sample": do_sample,
"use_cache": False,
"repetition_penalty": 1.05,
}
eos_token_id = self._get_generation_stop_ids()
if eos_token_id is not None:
generation_kwargs["eos_token_id"] = eos_token_id
pad_token_id = self._get_pad_token_id()
if pad_token_id is not None:
generation_kwargs["pad_token_id"] = pad_token_id
if do_sample:
generation_kwargs["temperature"] = DEFAULT_TEMPERATURE
generation_kwargs["top_p"] = DEFAULT_TOP_P
generation_kwargs["top_k"] = DEFAULT_TOP_K
with self.torch.inference_mode():
outputs = self.model.generate(
**input_ids,
**generation_kwargs,
)
input_length = input_ids["input_ids"].shape[-1]
response = self.processor.decode(
outputs[0][input_length:],
skip_special_tokens=False,
)
decoded = ""
if hasattr(self.processor, "parse_response"):
try:
decoded = normalize_parsed_response(
self.processor.parse_response(response)
)
except Exception:
decoded = ""
if not decoded.strip():
decoded = response.strip()
if not decoded.strip():
fallback = self.processor.decode(
outputs[0][input_length:],
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
).strip()
decoded = fallback
return strip_special_only_response(decoded), rendered_prompt
finally:
self._cleanup_generation_tensors(
input_ids,
outputs,
response,
fallback,
)
def _generate_from_inputs(
self,
inputs,
max_new_tokens: int,
*,
do_sample: bool,
use_parse_response: bool,
) -> str:
with self.torch.inference_mode():
output_ids = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=DEFAULT_TEMPERATURE if do_sample else None,
top_p=DEFAULT_TOP_P if do_sample else None,
top_k=DEFAULT_TOP_K if do_sample else None,
)
prompt_length = inputs["input_ids"].shape[1]
generated_ids = output_ids[:, prompt_length:]
raw_text = self.processor.decode(
generated_ids[0],
skip_special_tokens=False,
)
if use_parse_response and hasattr(self.processor, "parse_response"):
try:
parsed = self.processor.parse_response(raw_text)
parsed_text = normalize_parsed_response(parsed)
if parsed_text:
return parsed_text
except Exception:
pass
return self.processor.decode(
generated_ids[0],
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
).strip()
def _run_chat_template_request(
self,
messages: list[dict[str, object]],
*,
max_new_tokens: int,
do_sample: bool = True,
) -> str:
inputs = None
outputs = None
response = ""
fallback = ""
try:
inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
)
inputs = self._move_inputs(inputs)
input_len = inputs["input_ids"].shape[-1]
generation_kwargs = {
"max_new_tokens": max_new_tokens,
"do_sample": do_sample,
"use_cache": False,
}
eos_token_id = self._get_generation_stop_ids()
if eos_token_id is not None:
generation_kwargs["eos_token_id"] = eos_token_id
pad_token_id = self._get_pad_token_id()
if pad_token_id is not None:
generation_kwargs["pad_token_id"] = pad_token_id
if do_sample:
generation_kwargs["temperature"] = DEFAULT_TEMPERATURE
generation_kwargs["top_p"] = DEFAULT_TOP_P
generation_kwargs["top_k"] = DEFAULT_TOP_K
with self.torch.inference_mode():
outputs = self.model.generate(
**inputs,
**generation_kwargs,
)
response = self.processor.decode(
outputs[0][input_len:],
skip_special_tokens=False,
)
if hasattr(self.processor, "parse_response"):
try:
parsed = self.processor.parse_response(response)
parsed_text = normalize_parsed_response(parsed)
if strip_special_only_response(parsed_text):
return parsed_text
except Exception:
pass
stripped_response = response.strip()
if strip_special_only_response(stripped_response):
return stripped_response
fallback = self.processor.decode(
outputs[0][input_len:],
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
).strip()
return strip_special_only_response(fallback)
finally:
self._cleanup_generation_tensors(inputs, outputs, response, fallback)
def _run_chat_template_request_debug(
self,
messages: list[dict[str, object]],
*,
max_new_tokens: int,
do_sample: bool = False,
) -> tuple[str, str]:
inputs = None
outputs = None
response = ""
fallback = ""
try:
inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
)
inputs = self._move_inputs(inputs)
input_len = inputs["input_ids"].shape[-1]
generation_kwargs = {
"max_new_tokens": max_new_tokens,
"do_sample": do_sample,
"use_cache": False,
}
if do_sample:
generation_kwargs["temperature"] = DEFAULT_TEMPERATURE
generation_kwargs["top_p"] = DEFAULT_TOP_P
generation_kwargs["top_k"] = DEFAULT_TOP_K
with self.torch.inference_mode():
outputs = self.model.generate(
**inputs,
**generation_kwargs,
)
response = self.processor.decode(
outputs[0][input_len:],
skip_special_tokens=False,
).strip()
parsed_text = ""
if hasattr(self.processor, "parse_response"):
try:
parsed = self.processor.parse_response(response)
parsed_text = normalize_parsed_response(parsed)
except Exception:
parsed_text = ""
if not parsed_text:
fallback = self.processor.decode(
outputs[0][input_len:],
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
).strip()
parsed_text = fallback
return response, parsed_text.strip()
finally:
self._cleanup_generation_tensors(inputs, outputs, response, fallback)
def _run_audio_chat_template_request(
self,
messages: list[dict[str, object]],
*,
max_new_tokens: int,
) -> str:
inputs = None
outputs = None
response = ""
fallback = ""
try:
inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
)
inputs = self._move_inputs(inputs)
input_len = inputs["input_ids"].shape[-1]
with self.torch.inference_mode():
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
)
response = self.processor.decode(
outputs[0][input_len:],
skip_special_tokens=False,
)
if hasattr(self.processor, "parse_response"):
try:
parsed = self.processor.parse_response(response)
parsed_text = normalize_parsed_response(parsed)
if parsed_text:
return parsed_text
except Exception:
pass
fallback = self.processor.decode(
outputs[0][input_len:],
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
).strip()
return fallback or response.strip()
finally:
self._cleanup_generation_tensors(inputs, outputs, response, fallback)
def _run_text_chat_template_request(
self,
messages: list[dict[str, str]],
*,
max_new_tokens: int,
enable_thinking: bool,
) -> ChatResult:
if not hasattr(self.processor, "apply_chat_template"):
raise RuntimeError("The current processor does not support chat templates.")
inputs = None
outputs = None
response = ""
try:
try:
text = self.processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=enable_thinking,
)
except TypeError:
text = self.processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
inputs = self.processor(
text=text,
return_tensors="pt",
)
inputs = self._move_inputs(inputs)
input_len = inputs["input_ids"].shape[-1]
with self.torch.inference_mode():
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
)
response = self.processor.decode(
outputs[0][input_len:],
skip_special_tokens=False,
)
if hasattr(self.processor, "parse_response"):
try:
parsed = self.processor.parse_response(response)
return parse_chat_result(parsed, response)
except Exception:
pass
fallback = self.processor.decode(
outputs[0][input_len:],
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
).strip()
return ChatResult(
content=fallback or response.strip(),
thinking="",
raw_text=response.strip(),
)
finally:
self._cleanup_generation_tensors(inputs, outputs, response)
def _write_chunk_wav(self, path: Path, waveform, sample_rate: int) -> None:
pcm = self.np.clip(waveform, -1.0, 1.0)
pcm = (pcm * 32767).astype(self.np.int16)
with wave.open(str(path), "wb") as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(sample_rate)
wav_file.writeframes(pcm.tobytes())
def chat(
self,
messages: list[dict[str, str]],
max_new_tokens: int = 256,
system_prompt: str = "You are a helpful assistant.",
enable_thinking: bool = False,
) -> ChatResult:
self._ensure_causal_model()
if not messages:
raise RuntimeError("Chat messages are empty.")
if getattr(self.processor, "chat_template", None) is not None:
result = self._run_text_chat_template_request(
build_text_chat_template_messages(system_prompt, messages),
max_new_tokens=max_new_tokens,
enable_thinking=enable_thinking,
)
else:
prompt_text = build_chat_prompt(system_prompt, messages)
inputs = self.processor(
text=prompt_text,
add_special_tokens=True,
return_tensors="pt",
)
inputs = self._move_inputs(inputs)
text = self._generate_from_inputs(
inputs,
max_new_tokens,
do_sample=False,
use_parse_response=True,
).strip()
result = ChatResult(content=text, thinking="", raw_text=text)
return ChatResult(
content=clean_chat_text(result.content),
thinking=result.thinking.strip(),
raw_text=result.raw_text,
)
def analyze_image(
self,
image_path: Path,
*,
prompt: str,
language: str,
max_new_tokens: int = 256,
) -> str:
self._ensure_multimodal_model()
if not image_path.is_file():
raise RuntimeError(f"Image file not found: {image_path}")
if getattr(self.processor, "chat_template", None) is None:
raise RuntimeError(
"This image flow requires a processor with a chat template. "
"Use a Gemma 4 instruction-tuned checkpoint such as "
"`google/gemma-4-E2B-it` or a local instruction-tuned export."
)
prompt_text = build_image_prompt(language, prompt)
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image_path.as_posix()},
{"type": "text", "text": prompt_text},
],
}
]
input_ids = self.processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
)
input_ids = self._move_inputs(input_ids)
raw_text = self._generate_from_inputs(
input_ids,
max_new_tokens=max_new_tokens,
do_sample=True,
use_parse_response=True,
)
return raw_text.strip()
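    # Whole-video description: run the per-segment analysis below, then fuse
    # the segment texts into one final description with a text-only chat
    # turn. Note that render_video_ui also calls analyze_video_chunks
    # directly, so in the current UI each segment is analyzed twice.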
def analyze_video(
self,
video_path: Path,
*,
prompt: str,
language: str,
max_new_tokens: int = 256,
) -> str:
self._ensure_multimodal_model()
if not video_path.is_file():
raise RuntimeError(f"Video file not found: {video_path}")
if not is_video_file(video_path):
raise RuntimeError(f"Unsupported video file: {video_path}")
chunk_results = self.analyze_video_chunks(
video_path,
prompt=prompt,
language=language,
max_new_tokens=max_new_tokens,
)
chunk_texts = [
result.cleaned_text or result.raw_text
for result in chunk_results
if (result.cleaned_text or result.raw_text).strip()
]
if not chunk_texts:
return ""
if getattr(self.processor, "chat_template", None) is None:
raise RuntimeError(
"This video flow requires a processor with a chat template. "
"Use a Gemma 4 instruction-tuned checkpoint such as "
"`google/gemma-4-E2B-it` or a local instruction-tuned export."
)
prompt_text = build_video_summary_prompt(language, chunk_texts)
messages = [
{
"role": "user",
"content": [{"type": "text", "text": prompt_text}],
}
]
return self._run_chat_template_request(
messages,
max_new_tokens=min(max_new_tokens, 192),
do_sample=False,
).strip()
def analyze_video_chunks(
self,
video_path: Path,
*,
prompt: str,
language: str,
max_new_tokens: int = 256,
chunk_seconds: int = DEFAULT_VIDEO_CHUNK_SECONDS,
) -> list[VideoAnalysisChunkResult]:
self._ensure_multimodal_model()
if not video_path.is_file():
raise RuntimeError(f"Video file not found: {video_path}")
if not is_video_file(video_path):
raise RuntimeError(f"Unsupported video file: {video_path}")
if getattr(self.processor, "chat_template", None) is None:
raise RuntimeError(
"This video flow requires a processor with a chat template. "
"Use a Gemma 4 instruction-tuned checkpoint such as "
"`google/gemma-4-E2B-it` or a local instruction-tuned export."
)
total_duration = get_video_duration_seconds(
video_path,
self.settings.ffmpeg_path,
)
if total_duration is None or total_duration <= 0:
total_duration = float(chunk_seconds)
results: list[VideoAnalysisChunkResult] = []
with TemporaryDirectory(prefix="gemma4_video_") as temp_dir:
chunk_index = 0
start_seconds = 0.0
while start_seconds < total_duration:
end_seconds = min(start_seconds + chunk_seconds, total_duration)
segment_path = Path(temp_dir) / f"segment_{chunk_index + 1:03d}.mp4"
split_video_segment(
video_path,
segment_path,
start_time=start_seconds,
duration_seconds=max(end_seconds - start_seconds, 1.0),
ffmpeg_path=self.settings.ffmpeg_path,
)
prompt_text = build_video_chunk_prompt(
language,
prompt,
start_seconds,
end_seconds,
)
messages = [
{
"role": "user",
"content": [
{"type": "video", "video": segment_path.as_posix()},
{"type": "text", "text": prompt_text},
],
}
]
raw_text, parsed_text = self._run_chat_template_request_debug(
messages,
max_new_tokens=min(max_new_tokens, 256),
do_sample=False,
)
results.append(
VideoAnalysisChunkResult(
chunk_index=chunk_index,
start_seconds=start_seconds,
end_seconds=end_seconds,
prompt_text=prompt_text,
raw_text=raw_text,
cleaned_text=parsed_text,
)
)
chunk_index += 1
start_seconds = end_seconds
return results
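# transcribe_chunks: the core audio pipeline. Video inputs are first demuxed to WAV,
# the waveform is sliced into chunk_seconds windows, and each window is written to a
# temporary WAV file and transcribed. Every result is cleaned, salvaged if cleaning
# emptied it, normalized, deduplicated, and finally discarded when it looks low quality.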
def transcribe_chunks(
self,
audio_path: Path,
progress_callback: Callable[[str], None] | None = None,
chunk_result_callback: Callable[[ChunkTranscriptionResult], None] | None = None,
) -> list[ChunkTranscriptionResult]:
self._ensure_multimodal_model()
if not audio_path.is_file():
raise RuntimeError(f"Media file not found: {audio_path}")
if getattr(self.processor, "chat_template", None) is None:
raise RuntimeError(
"This transcription flow requires a processor with a chat template. "
"Use a Gemma 4 instruction-tuned checkpoint such as "
"`google/gemma-4-E2B-it` or a local instruction-tuned export."
)
with TemporaryDirectory(prefix="gemma4_audio_") as temp_dir:
source_audio_path = audio_path
if is_video_file(audio_path):
if progress_callback is not None:
progress_callback("Extracting audio from video...")
extracted_audio_path = Path(temp_dir) / "extracted_audio.wav"
extract_audio_from_video(
audio_path,
extracted_audio_path,
self.settings.sample_rate,
self.settings.ffmpeg_path,
)
source_audio_path = extracted_audio_path
waveform, sample_rate = load_audio(
self.librosa, source_audio_path, self.settings.sample_rate
)
chunk_size = sample_rate * self.settings.chunk_seconds
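# Build the transcription prompt once; the same text is reused for generation,
# cleanup, and salvage below.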
prompt = build_transcription_prompt(
self.settings.language,
self.settings.mode,
)
chunks = [
chunk for _, chunk in iter_chunks(waveform, chunk_size) if len(chunk) > 0
]
if not chunks:
return []
self._cleanup_generation_tensors()
chunk_temp_dir = Path(temp_dir) / "chunks"
chunk_temp_dir.mkdir(parents=True, exist_ok=True)
if progress_callback is not None:
progress_callback(f"Preparing {len(chunks)} chunk(s)...")
results: list[ChunkTranscriptionResult] = []
for chunk_index, chunk in enumerate(chunks, start=1):
if progress_callback is not None:
progress_callback(
f"Generating chunk {chunk_index}/{len(chunks)}..."
)
chunk_wav_path = chunk_temp_dir / f"chunk_{chunk_index:04d}.wav"
self._write_chunk_wav(chunk_wav_path, chunk, sample_rate)
# The prompt does not vary per chunk, so reuse the one built before the loop.
prompt_text = prompt
messages = build_audio_transcription_messages(
chunk_wav_path.as_posix(),
prompt_text,
)
decoded = self._run_audio_chat_template_request(
messages,
max_new_tokens=min(self.settings.max_new_tokens, 256),
)
cleaned_text = clean_transcript_text(decoded, prompt)
final_text = cleaned_text or salvage_transcript_text(decoded, prompt)
final_text = normalize_transcript_output(
final_text,
self.settings.language,
self.settings.mode,
)
final_text = dedupe_repeated_phrases(final_text)
if is_low_quality_transcript(final_text):
final_text = ""
results.append(
ChunkTranscriptionResult(
# The loop counts from 1 for progress messages; stored indices are zero-based.
chunk_index=chunk_index - 1,
prompt_text=prompt_text,
raw_text=decoded,
cleaned_text=final_text,
)
)
if chunk_result_callback is not None:
chunk_result_callback(results[-1])
self._cleanup_generation_tensors()
return results
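# transcribe: convenience wrapper that joins the cleaned chunk texts with newlines.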
def transcribe(
self,
audio_path: Path,
progress_callback: Callable[[str], None] | None = None,
) -> str:
chunk_results = self.transcribe_chunks(
audio_path,
progress_callback=progress_callback,
)
transcript_parts = [
result.cleaned_text for result in chunk_results if result.cleaned_text
]
return "\n".join(transcript_parts).strip()
def unload(self) -> None:
model = getattr(self, "model", None)
torch_module = getattr(self, "torch", None)
if model is not None:
try:
model.cpu()
except Exception:
pass
del self.model
if hasattr(self, "processor"):
del self.processor
gc.collect()
if torch_module is not None and torch_module.cuda.is_available():
torch_module.cuda.empty_cache()
try:
torch_module.cuda.ipc_collect()
except Exception:
pass
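# ---------------------------------------------------------------------------
# The lines below belong to a separate file: the command-line entry point for
# the transcriber. Its filename is not preserved in this dump; a name such as
# transcribe_cli.py is assumed.
# ---------------------------------------------------------------------------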
from __future__ import annotations
import argparse
import sys
from pathlib import Path
from transcriber_core import DEFAULT_CHUNK_SECONDS, DEFAULT_MODEL_DIR
from transcriber_core import GemmaTranscriber, TranscriberSettings
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Transcribe an audio or video file with a local Gemma model."
)
parser.add_argument(
"audio_file",
type=Path,
help="Path to the audio or video file.",
)
parser.add_argument(
"--model-dir",
type=Path,
default=DEFAULT_MODEL_DIR,
help=f"Path to the local Gemma model. Default: {DEFAULT_MODEL_DIR}",
)
parser.add_argument(
"--language",
default="ja",
help="Target transcript language, for example ja or en. Default: ja",
)
parser.add_argument(
"--mode",
choices=["speech", "song"],
default="speech",
help="Transcription mode. Use song for singing vocals. Default: speech",
)
parser.add_argument(
"--chunk-seconds",
type=int,
default=DEFAULT_CHUNK_SECONDS,
help=f"Audio chunk length in seconds. Default: {DEFAULT_CHUNK_SECONDS}",
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=256,
help="Maximum tokens to generate for each chunk. Default: 256",
)
parser.add_argument(
"--dtype",
choices=["auto", "bfloat16", "float16", "float32"],
default="auto",
help="Torch dtype for model loading. Default: auto",
)
parser.add_argument(
"--device-map",
default="auto",
help="Transformers device_map value. Default: auto",
)
parser.add_argument(
"--max-memory",
default='{"cuda:0":"10GiB","cuda:1":"10GiB","cpu":"32GiB"}',
help='max_memory JSON passed to the model loader for multi-GPU setups. '
'Default: {"cuda:0":"10GiB","cuda:1":"10GiB","cpu":"32GiB"}',
)
parser.add_argument(
"--ffmpeg-path",
default="",
help="Optional full path to ffmpeg.exe for video transcription.",
)
parser.add_argument(
"--output",
type=Path,
help="Optional output text file path.",
)
return parser.parse_args()
def main() -> int:
args = parse_args()
settings = TranscriberSettings(
model_dir=args.model_dir,
language=args.language,
mode=args.mode,
chunk_seconds=args.chunk_seconds,
max_new_tokens=args.max_new_tokens,
dtype=args.dtype,
device_map=args.device_map,
max_memory=args.max_memory,
ffmpeg_path=args.ffmpeg_path,
)
try:
print("Loading model and processor...", file=sys.stderr)
transcriber = GemmaTranscriber(settings)
final_text = transcriber.transcribe(
args.audio_file,
progress_callback=lambda message: print(message, file=sys.stderr),
)
except RuntimeError as exc:
raise SystemExit(str(exc)) from exc
print(final_text)
if args.output:
args.output.write_text(final_text, encoding="utf-8")
print(f"Saved transcript to: {args.output}", file=sys.stderr)
return 0
if __name__ == "__main__":
raise SystemExit(main())
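# Example invocations, assuming the file above is saved as transcribe_cli.py (the
# real filename is not shown in this dump) and transcriber_core is importable:
#   python transcribe_cli.py recording.mp4 --language ja --mode speech --output transcript.txt
#   python transcribe_cli.py vocals.wav --mode song --chunk-seconds 20 --dtype bfloat16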