使用 Python 构建交互式鸢尾花分类器 GUI
Source: Dev.to
使用 Python 构建交互式鸢尾花分类器 GUI
在本教程中,我们将使用 scikit‑learn 训练一个简单的鸢尾花(Iris)分类模型,并使用 tkinter 为它构建一个图形用户界面(GUI)。用户可以在界面中输入萼片和花瓣的长度/宽度,点击按钮后实时得到预测的花种类。
目录
准备工作
首先,确保已安装以下 Python 包:
pip install scikit-learn pandas matplotlib
提示:如果你已经有
tkinter(大多数 Python 安装默认自带),则无需额外安装。
加载数据并训练模型
我们使用 sklearn.datasets 中自带的 Iris 数据集。下面的代码完成数据划分、模型训练以及保存模型的步骤。
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
import joblib
# 加载数据
iris = load_iris()
X = iris.data # 四个特征:萼片长度、萼片宽度、花瓣长度、花瓣宽度
y = iris.target # 0: setosa, 1: versicolor, 2: virginica
feature_names = iris.feature_names
target_names = iris.target_names
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# 训练随机森林分类器
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)
# 打印模型在测试集上的准确率
print(f"Test accuracy: {clf.score(X_test, y_test):.2%}")
# 将模型保存到磁盘,后续 GUI 中直接加载
joblib.dump(clf, "iris_classifier.pkl")
说明
RandomForestClassifier只是一种示例,你也可以换成LogisticRegression、SVC等其他模型。joblib.dump用于序列化模型,方便在 GUI 中直接加载而无需每次重新训练。
构建 GUI
下面的代码使用 tkinter 创建一个简洁的窗口,包含四个 Entry(用于输入特征值)和一个 Button(用于触发预测)。预测结果会显示在标签 (Label) 中。
import tkinter as tk
from tkinter import ttk, messagebox
import joblib
import numpy as np
# 加载已经训练好的模型
model = joblib.load("iris_classifier.pkl")
target_names = ["setosa", "versicolor", "virginica"]
def predict():
try:
# 读取用户输入的四个特征值
sepal_length = float(entry_sepal_length.get())
sepal_width = float(entry_sepal_width.get())
petal_length = float(entry_petal_length.get())
petal_width = float(entry_petal_width.get())
except ValueError:
messagebox.showerror("输入错误", "请确保所有字段均为数字。")
return
# 将特征值组织成模型需要的形状 (1, 4)
sample = np.array([[sepal_length, sepal_width, petal_length, petal_width]])
pred = model.predict(sample)[0]
result_var.set(f"预测结果: {target_names[pred]}")
# 创建主窗口
root = tk.Tk()
root.title("鸢尾花分类器")
root.geometry("350x300")
root.resizable(False, False)
# 样式
style = ttk.Style()
style.configure("TLabel", font=("Helvetica", 10))
style.configure("TButton", font=("Helvetica", 10, "bold"))
# 输入框标签和控件
ttk.Label(root, text="萼片长度 (cm):").grid(row=0, column=0, padx=10, pady=5, sticky="e")
entry_sepal_length = ttk.Entry(root, width=15)
entry_sepal_length.grid(row=0, column=1, padx=10, pady=5)
ttk.Label(root, text="萼片宽度 (cm):").grid(row=1, column=0, padx=10, pady=5, sticky="e")
entry_sepal_width = ttk.Entry(root, width=15)
entry_sepal_width.grid(row=1, column=1, padx=10, pady=5)
ttk.Label(root, text="花瓣长度 (cm):").grid(row=2, column=0, padx=10, pady=5, sticky="e")
entry_petal_length = ttk.Entry(root, width=15)
entry_petal_length.grid(row=2, column=1, padx=10, pady=5)
ttk.Label(root, text="花瓣宽度 (cm):").grid(row=3, column=0, padx=10, pady=5, sticky="e")
entry_petal_width = ttk.Entry(root, width=15)
entry_petal_width.grid(row=3, column=1, padx=10, pady=5)
# 预测按钮
ttk.Button(root, text="预测", command=predict).grid(row=4, column=0, columnspan=2, pady=15)
# 结果显示
result_var = tk.StringVar(value="预测结果: —")
ttk.Label(root, textvariable=result_var, foreground="blue", font=("Helvetica", 12, "bold")).grid(
row=5, column=0, columnspan=2, pady=10
)
root.mainloop()
代码要点说明
| 部分 | 作用 |
|---|---|
joblib.load("iris_classifier.pkl") | 从磁盘读取已经训练好的模型,避免每次启动 GUI 时重新训练。 |
predict() 函数 | 读取用户输入、进行类型检查、调用模型的 predict 方法并更新结果标签。 |
ttk.Entry 与 ttk.Label | 使用 ttk(主题化小部件)让界面更美观。 |
messagebox.showerror | 当用户输入非数字时弹出错误提示,提升用户体验。 |
完整代码示例
将 模型训练代码 与 GUI 代码 分别保存为 train_model.py 与 iris_gui.py,执行顺序如下:
python train_model.py # 生成 iris_classifier.pkl
python iris_gui.py # 启动交互式 GUI
小技巧:如果你想一次性运行所有代码,可以把两段代码合并到同一个脚本中,只在第一次运行时训练模型,后续直接加载。
运行与演示
-
运行训练脚本
python train_model.py终端会输出类似
Test accuracy: 96.67%,并在当前目录生成iris_classifier.pkl。 -
启动 GUI
python iris_gui.py会弹出如下窗口:

(示例截图,仅供参考) -
使用
- 在四个输入框中填写萼片/花瓣的长度与宽度(单位:厘米)。
- 点击 预测 按钮,标签会显示预测的花种类(
setosa、versicolor或virginica)。
小结
- 机器学习部分:使用
scikit-learn快速加载 Iris 数据集并训练随机森林模型。 - 模型持久化:通过
joblib将模型保存为二进制文件,便于在 GUI 中直接加载。 - GUI 实现:利用
tkinter(Python 标准库)构建简洁的交互界面,实现实时预测。 - 可扩展性:你可以替换模型、添加特征可视化(如
matplotlib)或使用更高级的 GUI 框架(如PyQt、Kivy)来提升用户体验。
祝你玩得开心,尽情探索机器学习与桌面应用的结合吧!
我们将使用
- scikit‑learn – 机器学习
- pandas – CSV 处理
- tkinter + ttkbootstrap – GUI
- 可选: tkinterdnd2 – 拖拽支持
您可以在此处克隆完整的仓库:
🔗 Iris-Flower-Classifier-GUI on GitHub
安装依赖
pip install pandas scikit-learn ttkbootstrap
# Optional: for drag & drop
pip install tkinterdnd2
注意:
tkinter在大多数系统上随 Python 预装。
导入库
import os, sys, threading
import pandas as pd
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import ttkbootstrap as tb
from ttkbootstrap.constants import *
# Optional drag & drop
try:
from tkinterdnd2 import TkinterDnD, DND_FILES
DND_ENABLED = True
except ImportError:
DND_ENABLED = False
print("Drag & Drop requires tkinterdnd2: pip install tkinterdnd2")
这里我们导入:
pandas用于处理 CSV 文件tkinter与ttkbootstrap用于现代 GUI 元素tkinterdnd2(如果需要)用于拖拽 CSV 支持
创建 Iris 机器学习模型
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
class IrisModel:
def __init__(self):
data = load_iris()
self.X = data.data
self.y = data.target
self.target_names = data.target_names
self.scaler = StandardScaler()
# Scale features
X_scaled = self.scaler.fit_transform(self.X)
# Train/test split
X_train, X_test, y_train, y_test = train_test_split(
X_scaled, self.y, test_size=0.2, random_state=42
)
# Random Forest classifier
self.clf = RandomForestClassifier(n_estimators=100, random_state=42)
self.clf.fit(X_train, y_train)
def predict(self, X):
X_scaled = self.scaler.transform(X)
preds = self.clf.predict(X_scaled)
return [self.target_names[p] for p in preds]
✅ 说明
- 对特征进行标准化,以提升模型性能。
- 随机森林分类器对初学者友好,并且在此分类任务上表现良好。
CSV 处理工作者
此工作者让 GUI 能够在不冻结界面的情况下处理多个 CSV 文件。
class ClassifierWorker:
def __init__(self, files, callbacks):
self.files = files
self.callbacks = callbacks
self._running = True
self.model = IrisModel()
def stop(self):
self._running = False
def run(self):
total = len(self.files)
for i, file in enumerate(self.files):
if not self._running:
break
try:
df = pd.read_csv(file)
required = {"sepal_length", "sepal_width", "petal_length", "petal_width"}
if set(df.columns) >= required:
X = df[["sepal_length", "sepal_width",
"petal_length", "petal_width"]].values
preds = self.model.predict(X)
if "found" in self.callbacks:
self.callbacks["found"](file, preds)
else:
if "found" in self.callbacks:
self.callbacks["found"](file, ["Error: Missing required columns"])
except Exception as e:
if "found" in self.callbacks:
self.callbacks["found"](file, [f"Error: {str(e)}"])
# Progress update
if "progress" in self.callbacks:
self.callbacks["progress"](int((i + 1) / total * 100))
# Finished
if "finished" in self.callbacks:
self.callbacks["finished"]()
✅ 说明
- 逐个读取 CSV 文件。
- 使用 Iris 模型预测物种。
- 异步向 GUI 发送进度更新。
构建 GUI
主应用程序类
class IrisClassifierApp:
def __init__(self):
if DND_ENABLED:
self.root = TkinterDnD.Tk()
else:
self.root = tb.Window(themename="darkly")
self.root.title("IrisClassifier v1.1")
self.root.minsize(1000, 700)
self.worker_obj = None
self.file_set = set()
self.model = IrisModel()
self._build_ui()
self._apply_styles()
- 检查是否支持拖拽。
- 为现代暗色主题初始化
ttkbootstrap。
GUI 布局
def _build_ui(self):
main = tb.Frame(self.root, padding=10)
main.pack(fill="both", expand=True)
tb.Label(
main,
text="🌸 Iris Flower Classifier",
font=("Segoe UI", 20, "bold")
).pack(pady=(0, 10))
# 文件选择行
row1 = tb.Frame(main)
row1.pack(fill="x", pady=(0, 6))
self.path_input = tb.Entry(row1, width=80)
self.path_input.pack(side="left", fill="x", expand=True, padx=(0, 6))
self.path_input.insert(0, "Drag & drop CSV files here…")
✅ 说明
- 为文件路径创建输入框。
- 添加简洁、友好的标签。
按钮与进度条
browse_btn = tb.Button(
row1, text="📂 Browse", bootstyle="info", command=self.browse
)
browse_btn.pack(side="left", padx=3)
self.start_btn = tb.Button(
row1, text="🚀 Classify CSV", bootstyle="success", command=self.start
)
self.start_btn.pack(side="left", padx=3)
self.cancel_btn = tb.Button(
row1, text="⏹ Cancel", bootstyle="danger", command=self.cancel
)
self.cancel_btn.pack(side="left", padx=3)
self.cancel_btn.config(state="disabled")
self.progress = tb.Progressbar(
main, bootstyle="success-striped", maximum=100
)
self.progress.pack(fill="x", pady=10)
- Browse – 打开文件对话框。
- Classify CSV – 启动工作线程。
- Cancel – 停止处理。
- 进度条用于可视化分类进度。
(其余方法——browse、start、cancel、_apply_styles 等——在原仓库中实现,负责文件选择、线程管理和 UI 更新。)
运行应用程序
if __name__ == "__main__":
app = IrisClassifierApp()
app.root.mainloop()
就这样!将包含四个 Iris 测量值的 CSV 文件拖入或浏览,点击 Classify,即可在界面不冻结的情况下看到预测结果。祝您在机器学习和 GUI 开发中玩得开心!
GUI 控件
浏览:选择 CSV 文件
分类 CSV:开始预测
取消:停止预测
进度条:可视化反馈
添加手动输入
manual_frame = tb.Labelframe(main, text="Manual Input", padding=10)
manual_frame.pack(fill="x", pady=(10, 6))
labels = ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"]
self.manual_entries = {}
for i, label in enumerate(labels):
tb.Label(manual_frame, text=label).grid(row=0, column=i * 2, sticky="w")
entry = tb.Entry(manual_frame, width=8)
entry.grid(row=0, column=i * 2 + 1)
entry.insert(0, "0")
self.manual_entries[label] = entry
predict_btn = tb.Button(
manual_frame,
text="🔮 Predict",
bootstyle="info",
command=self.manual_predict,
)
predict_btn.grid(row=0, column=8, padx=10)
self.manual_result = tb.Label(
manual_frame,
text="Prediction: ---",
font=("Segoe UI", 12, "bold"),
)
self.manual_result.grid(row=1, column=0, columnspan=9, pady=(6, 0), sticky="w")
✅ 说明
- 用户可以手动输入花瓣的测量值。
- 点击 Predict 可立即获取物种预测结果。
Run the App (alternative entry)
if __name__ == "__main__":
app = IrisClassifierApp()
app.run()
And that’s it! Your interactive Iris Flower Classifier is ready to run.
Optional Features
- Drag & Drop CSVs (requires
tkinterdnd2) - Export results to a text file
- Beautiful dark/light theme with ttkbootstrap
🎯 摘要
您已经学习了如何:
- 在 Iris 数据集上训练随机森林分类器
- 构建一个与之交互的 Python GUI
- 加载 CSV 文件并进行手动预测
- 可视化进度和结果
查看完整项目:
🔗 Iris-Flower-Classifier-GUI on GitHub
