第 7 章 Web UI(app.py) 本章目标:用 Gradio 把推理包成一个浏览器可访问的交互界面。 7.1 章节地图 app.py 干三件事: 7.2 完整源码 整个文件不到 110 行,但麻雀虽小五脏俱全。 7.3 Gradio 的核心抽象 Gradio 把 ML demo 简化为「输入组件 → 回调函数 → 输出组件」三件套。 最小例子 3 行就能起一个网页。本项目用更灵活的 自定义布局,但本质一样。
本章目标:用 Gradio 把推理包成一个浏览器可访问的交互界面。
app.py 干三件事:
┌─────────────────────────────────────────────────────────┐ │ 1. 全局 generator 惰性加载 + 缓存(避免重载) │ │ 2. predict() 回调 异常兜底 + 调用 generator │ │ 3. build_ui() Gradio Blocks 自定义布局 │ └─────────────────────────────────────────────────────────┘
""" app.py ====== 基于 Gradio 的 Web UI 入口。 提供: - 文本输入框:输入提示词。 - 滑块:max_tokens / temperature / top_k。 - 输出文本框:显示模型生成结果。 后端复用 inference.py 中的 GPTGenerator.generate()。 """ import argparse import os import gradio as gr from inference import load_for_inference, GPTGenerator # 全局 generator,避免每次请求都重新加载模型。 _GENERATOR: GPTGenerator | None = None def get_generator(checkpoint: str) -> GPTGenerator: """惰性加载并缓存全局 generator。""" global _GENERATOR if _GENERATOR is None: print(f"[app] 加载模型: {checkpoint}") _GENERATOR = load_for_inference(checkpoint) return _GENERATOR def predict(prompt: str, max_tokens: int, temperature: float, top_k: int, checkpoint: str) -> str: """Gradio 回调函数:调用模型生成文本。""" try: gen = get_generator(checkpoint) result = gen.generate( prompt=prompt, max_new_tokens=int(max_tokens), temperature=float(temperature), top_k=int(top_k), ) return result except Exception as e: # noqa: BLE001 return f"[生成失败] {type(e).__name__}: {e}" def build_ui(default_checkpoint: str) -> gr.Blocks: """构建 Gradio 界面。""" with gr.Blocks(title="GPT 文本生成器") as demo: gr.Markdown("# 🤖 GPT 文本生成器\n基于小型 GPT-2 的自回归文本生成 Demo。") with gr.Row(): with gr.Column(scale=3): prompt_input = gr.Textbox( label="提示词 (Prompt)", placeholder="例如: To be, or not to be...", lines=4, ) generate_btn = gr.Button("生成", variant="primary") output_box = gr.Textbox(label="生成结果", lines=12) with gr.Column(scale=2): max_tokens_slider = gr.Slider( minimum=10, maximum=500, value=100, step=10, label="最大生成长度 (max_tokens)", ) temperature_slider = gr.Slider( minimum=0.0, maximum=1.5, value=0.8, step=0.05, label="温度 (temperature)", ) top_k_slider = gr.Slider( minimum=0, maximum=100, value=40, step=1, label="Top-K", ) checkpoint_box = gr.Textbox( value=default_checkpoint, label="Checkpoint 路径", ) # 点击按钮触发生成。 generate_btn.click( fn=predict, inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider, checkpoint_box], outputs=output_box, ) return demo def main(): parser = argparse.ArgumentParser(description="启动 GPT Gradio Web UI") parser.add_argument( "--checkpoint", type=str, default="checkpoints/gpt_final.pt", help="默认加载的 checkpoint 路径", ) parser.add_argument("--host", type=str, default="0.0.0.0", help="监听地址") parser.add_argument("--port", type=int, default=7860, help="监听端口") parser.add_argument("--share", action="store_true", help="是否生成公网链接") args = parser.parse_args() demo = build_ui(default_checkpoint=args.checkpoint) demo.launch(server_name=args.host, server_port=args.port, share=args.share) if __name__ == "__main__": main()
整个文件不到 110 行,但麻雀虽小五脏俱全。
Gradio 把 ML demo 简化为「输入组件 → 回调函数 → 输出组件」三件套。
import gradio as gr def predict(input_text): return "模型输出:" + input_text demo = gr.Interface(fn=predict, inputs="text", outputs="text") demo.launch()
3 行就能起一个网页。本项目用更灵活的 gr.Blocks 自定义布局,但本质一样。
| 方式 | 优点 | 适用 |
|---|---|---|
gr.Interface |
一行起界面,超简单 | 简单单输入单输出 |
gr.Blocks |
灵活布局,多组件 | 本项目(要放滑块、文本框、按钮组合) |
本项目选 Blocks,因为要同时展示「prompt 输入、3 个滑块、checkpoint 输入、输出」。
_GENERATOR: GPTGenerator | None = None def get_generator(checkpoint: str) -> GPTGenerator: global _GENERATOR if _GENERATOR is None: print(f"[app] 加载模型: {checkpoint}") _GENERATOR = load_for_inference(checkpoint) return _GENERATOR
模型加载慢(读磁盘、初始化、迁 GPU),不能每次点「生成」按钮都重载一次。用模块级全局变量缓存:
这是「惰性初始化 + 单例模式」的经典实现。
_GENERATOR: GPTGenerator | None = None
类型注解 GPTGenerator | None 是 Python 3.10+ 语法(等价于 Optional[GPTGenerator])。初始为 None。
global _GENERATOR
函数内要修改模块级变量必须 global 声明,否则 Python 会把它当成局部变量。
单例模式在多 worker(如 gunicorn -w 4)下会有 4 份模型副本,吃显存。生产中常用:
functools.lru_cache 包装。本项目单进程 demo.launch() 足够,不用想这么复杂。
def predict(prompt, max_tokens, temperature, top_k, checkpoint): try: gen = get_generator(checkpoint) result = gen.generate( prompt=prompt, max_new_tokens=int(max_tokens), temperature=float(temperature), top_k=int(top_k), ) return result except Exception as e: return f"[生成失败] {type(e).__name__}: {e}"
Gradio 默认会把未捕获异常显示成 500 错误页,用户体验差。这里捕获后把错误信息塞进输出框,用户能立刻看到原因(如 checkpoint 不存在)。
max_new_tokens=int(max_tokens), temperature=float(temperature), top_k=int(top_k),
Gradio 滑块传过来的值虽然应该是数字,但显式 int() / float() 转换是防御式编程——万一前端传了字符串 "40" 也能处理。
return f"[生成失败] {type(e).__name__}: {e}"
返回「异常类型 + 消息」,如:
[生成失败] FileNotFoundError: checkpoints/gpt_final.pt 不存在 [生成失败] RuntimeError: CUDA out of memory
比单纯返回 str(e) 更有诊断价值(能看到异常类型)。
def build_ui(default_checkpoint): with gr.Blocks(title="GPT 文本生成器") as demo: gr.Markdown("# 🤖 GPT 文本生成器\n基于小型 GPT-2 的自回归文本生成 Demo。") with gr.Row(): # 横向分两列 with gr.Column(scale=3): # 左列占 3/5 宽 prompt_input = gr.Textbox(...) generate_btn = gr.Button("生成", variant="primary") output_box = gr.Textbox(...) with gr.Column(scale=2): # 右列占 2/5 宽 max_tokens_slider = gr.Slider(...) temperature_slider = gr.Slider(...) top_k_slider = gr.Slider(...) checkpoint_box = gr.Textbox(...) generate_btn.click( fn=predict, inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider, checkpoint_box], outputs=output_box, ) return demo
┌──────────────────────────────────────────────────┐ │ # 🤖 GPT 文本生成器 │ ├────────────────────────────┬─────────────────────┤ │ 提示词 (Prompt) │ 最大生成长度 ───●── │ │ ┌────────────────────────┐ │ 温度 ──●─── │ │ │ To be, or │ │ Top-K ──●─── │ │ │ │ │ Checkpoint 路径 │ │ └────────────────────────┘ │ ┌─────────────────┐ │ │ [生成] │ │checkpoints/... │ │ │ 生成结果 │ └─────────────────┘ │ │ ┌────────────────────────┐ │ │ │ │ │ │ │ │ │ │ │ │ │ └────────────────────────┘ │ │ └────────────────────────────┴─────────────────────┘
gr.Blocks:容器with gr.Blocks(title="GPT 文本生成器") as demo: ...
Blocks 是一个声明式容器,with 块内定义的所有组件自动归到这个 demo 里。
gr.Markdown:富文本gr.Markdown("# 🤖 GPT 文本生成器\n...")
支持 Markdown 语法,用于标题、说明。比纯文本好看。
gr.Row / gr.Column:栅格布局with gr.Row(): # 一行 with gr.Column(scale=3): # 左列,占 3 份 ... with gr.Column(scale=2): # 右列,占 2 份 ...
scale 参数控制列宽比例(3:2 = 60%:40%)。
gr.Textbox:文本框prompt_input = gr.Textbox( label="提示词 (Prompt)", placeholder="例如: To be, or not to be...", lines=4, # 显示 4 行高 )
label:上方标签。placeholder:空时灰色提示文字。lines:高度(行数)。gr.Slider:滑块max_tokens_slider = gr.Slider( minimum=10, maximum=500, value=100, step=10, label="最大生成长度 (max_tokens)", )
minimum/maximum:范围。value:初始值。step:步长(每拖一格变化多少)。label:标签。gr.Button:按钮generate_btn = gr.Button("生成", variant="primary")
variant="primary" 让按钮高亮(绿色/蓝色,看主题),强调主要操作。
generate_btn.click( fn=predict, inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider, checkpoint_box], outputs=output_box, )
把按钮点击事件绑定到 predict 函数:
inputs 列表里所有组件的值,作为参数传给 predict。predict 返回值赋给 outputs(这里 output_box)。inputs 顺序必须与 predict 参数顺序一致:
inputs = [prompt_input, max_tokens_slider, temperature_slider, top_k_slider, checkpoint_box] ↓ ↓ ↓ ↓ ↓ def predict(prompt, max_tokens, temperature, top_k, checkpoint):
def main(): parser = argparse.ArgumentParser(description="启动 GPT Gradio Web UI") parser.add_argument("--checkpoint", default="checkpoints/gpt_final.pt") parser.add_argument("--host", default="0.0.0.0") parser.add_argument("--port", default=7860) parser.add_argument("--share", action="store_true") args = parser.parse_args() demo = build_ui(default_checkpoint=args.checkpoint) demo.launch(server_name=args.host, server_port=args.port, share=args.share)
--hostparser.add_argument("--host", default="0.0.0.0")
监听地址:
0.0.0.0(默认):监听所有网卡,允许局域网其他机器访问。127.0.0.1:只本机访问(更安全)。--portparser.add_argument("--port", default=7860)
Gradio 默认 7860。被占用时换一个,如 --port 8080。
--shareparser.add_argument("--share", action="store_true")
调 Gradio 的 tunnel 服务,生成一个 xxx.gradio.live 公网链接,几小时内有效,适合给远端同事演示。
⚠️ 注意:开了 share 等于把你的模型暴露到公网,敏感场景别开。
--checkpointparser.add_argument("--checkpoint", default="checkpoints/gpt_final.pt")
默认加载最终模型。可指向任意中间 checkpoint:--checkpoint checkpoints/gpt_step2000.pt。
# 本机访问(默认) python app.py # 局域网访问 python app.py --host 0.0.0.0 --port 8080 # 公网临时链接 python app.py --share # 加载特定 checkpoint python app.py --checkpoint checkpoints/gpt_step2000.pt
启动后浏览器打开 本地端口 7860(或对应 host:port)即可。
学完基础,可以试这些:
让生成过程像 ChatGPT 一样逐字显示:
def predict(prompt, ...): for token in gen.generate_streaming(prompt, ...): yield token # 用 yield 而非 return
需要把 GPTGenerator.generate 改造成生成器,每生成一个 token yield 一次。
gr.Examples( examples=[ ["To be, or", 100, 0.8, 40], ["ROMEO:", 200, 0.9, 50], ["Once upon a time", 150, 0.7, 40], ], inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider], )
用户点一下就填入预设值,超方便。
clear_btn = gr.Button("清空") clear_btn.click(fn=lambda: ("", ""), inputs=None, outputs=[prompt_input, output_box])
chatbot = gr.Chatbot() # 维护对话历史,每次生成都带上历史
demo = gr.Blocks(css=".gradio-container {max-width: 800px; margin: auto;}")
加一个清空按钮:点击后清空 prompt 和 output。
clear_btn = gr.Button("清空") clear_btn.click(fn=lambda: ("", ""), outputs=[prompt_input, output_box])
加示例按钮:用 gr.Examples 提供几个预设 prompt。
改温度上限:把 temperature 滑块上限改成 2.0,看看高温生成的效果(可能会乱码)。
加生成耗时显示:在 predict 里记录开始时间,返回 f"{result}\n\n耗时: {elapsed:.2f}s"。
加历史记录:用一个 gr.State 维护生成历史,每次生成后追加。
思考题:如果两个用户同时点「生成」,会发生什么?
gr.Blocks 比 gr.Interface 更灵活,支持 Row/Column 栅格布局。_GENERATOR 实现惰性加载,避免每次请求重载模型。try/except,把错误塞进输出框而非崩溃。demo.launch(host, port, share) 控制监听地址和公网分享。UI 也通了,去 第 8 章 工程实践要点 看把散落各章的工程经验集中讲一遍。