使用 Python 构建交互式鸢尾花分类器 GUI

发布: (2026年1月19日 GMT+8 09:30)
12 min read
原文: Dev.to

Source: Dev.to

使用 Python 构建交互式鸢尾花分类器 GUI

在本教程中,我们将使用 scikit‑learn 训练一个简单的鸢尾花(Iris)分类模型,并使用 tkinter 为它构建一个图形用户界面(GUI)。用户可以在界面中输入萼片和花瓣的长度/宽度,点击按钮后实时得到预测的花种类。


目录

  1. 准备工作
  2. 加载数据并训练模型
  3. 构建 GUI
  4. 完整代码示例
  5. 运行与演示

准备工作

首先,确保已安装以下 Python 包:

pip install scikit-learn pandas matplotlib

提示:如果你已经有 tkinter(大多数 Python 安装默认自带),则无需额外安装。


加载数据并训练模型

我们使用 sklearn.datasets 中自带的 Iris 数据集。下面的代码完成数据划分、模型训练以及保存模型的步骤。

import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
import joblib

# 加载数据
iris = load_iris()
X = iris.data          # 四个特征:萼片长度、萼片宽度、花瓣长度、花瓣宽度
y = iris.target        # 0: setosa, 1: versicolor, 2: virginica
feature_names = iris.feature_names
target_names = iris.target_names

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# 训练随机森林分类器
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# 打印模型在测试集上的准确率
print(f"Test accuracy: {clf.score(X_test, y_test):.2%}")

# 将模型保存到磁盘,后续 GUI 中直接加载
joblib.dump(clf, "iris_classifier.pkl")

说明

  • RandomForestClassifier 只是一种示例,你也可以换成 LogisticRegressionSVC 等其他模型。
  • joblib.dump 用于序列化模型,方便在 GUI 中直接加载而无需每次重新训练。

构建 GUI

下面的代码使用 tkinter 创建一个简洁的窗口,包含四个 Entry(用于输入特征值)和一个 Button(用于触发预测)。预测结果会显示在标签 (Label) 中。

import tkinter as tk
from tkinter import ttk, messagebox
import joblib
import numpy as np

# 加载已经训练好的模型
model = joblib.load("iris_classifier.pkl")
target_names = ["setosa", "versicolor", "virginica"]

def predict():
    try:
        # 读取用户输入的四个特征值
        sepal_length = float(entry_sepal_length.get())
        sepal_width  = float(entry_sepal_width.get())
        petal_length = float(entry_petal_length.get())
        petal_width  = float(entry_petal_width.get())
    except ValueError:
        messagebox.showerror("输入错误", "请确保所有字段均为数字。")
        return

    # 将特征值组织成模型需要的形状 (1, 4)
    sample = np.array([[sepal_length, sepal_width, petal_length, petal_width]])
    pred = model.predict(sample)[0]
    result_var.set(f"预测结果: {target_names[pred]}")

# 创建主窗口
root = tk.Tk()
root.title("鸢尾花分类器")
root.geometry("350x300")
root.resizable(False, False)

# 样式
style = ttk.Style()
style.configure("TLabel", font=("Helvetica", 10))
style.configure("TButton", font=("Helvetica", 10, "bold"))

# 输入框标签和控件
ttk.Label(root, text="萼片长度 (cm):").grid(row=0, column=0, padx=10, pady=5, sticky="e")
entry_sepal_length = ttk.Entry(root, width=15)
entry_sepal_length.grid(row=0, column=1, padx=10, pady=5)

ttk.Label(root, text="萼片宽度 (cm):").grid(row=1, column=0, padx=10, pady=5, sticky="e")
entry_sepal_width = ttk.Entry(root, width=15)
entry_sepal_width.grid(row=1, column=1, padx=10, pady=5)

ttk.Label(root, text="花瓣长度 (cm):").grid(row=2, column=0, padx=10, pady=5, sticky="e")
entry_petal_length = ttk.Entry(root, width=15)
entry_petal_length.grid(row=2, column=1, padx=10, pady=5)

ttk.Label(root, text="花瓣宽度 (cm):").grid(row=3, column=0, padx=10, pady=5, sticky="e")
entry_petal_width = ttk.Entry(root, width=15)
entry_petal_width.grid(row=3, column=1, padx=10, pady=5)

# 预测按钮
ttk.Button(root, text="预测", command=predict).grid(row=4, column=0, columnspan=2, pady=15)

# 结果显示
result_var = tk.StringVar(value="预测结果: —")
ttk.Label(root, textvariable=result_var, foreground="blue", font=("Helvetica", 12, "bold")).grid(
    row=5, column=0, columnspan=2, pady=10
)

root.mainloop()

代码要点说明

部分作用
joblib.load("iris_classifier.pkl")从磁盘读取已经训练好的模型,避免每次启动 GUI 时重新训练。
predict() 函数读取用户输入、进行类型检查、调用模型的 predict 方法并更新结果标签。
ttk.Entryttk.Label使用 ttk(主题化小部件)让界面更美观。
messagebox.showerror当用户输入非数字时弹出错误提示,提升用户体验。

完整代码示例

模型训练代码GUI 代码 分别保存为 train_model.pyiris_gui.py,执行顺序如下:

python train_model.py   # 生成 iris_classifier.pkl
python iris_gui.py      # 启动交互式 GUI

小技巧:如果你想一次性运行所有代码,可以把两段代码合并到同一个脚本中,只在第一次运行时训练模型,后续直接加载。


运行与演示

  1. 运行训练脚本

    python train_model.py

    终端会输出类似 Test accuracy: 96.67%,并在当前目录生成 iris_classifier.pkl

  2. 启动 GUI

    python iris_gui.py

    会弹出如下窗口:

    iris_gui_screenshot
    (示例截图,仅供参考)

  3. 使用

    • 在四个输入框中填写萼片/花瓣的长度与宽度(单位:厘米)。
    • 点击 预测 按钮,标签会显示预测的花种类(setosaversicolorvirginica)。

小结

  • 机器学习部分:使用 scikit-learn 快速加载 Iris 数据集并训练随机森林模型。
  • 模型持久化:通过 joblib 将模型保存为二进制文件,便于在 GUI 中直接加载。
  • GUI 实现:利用 tkinter(Python 标准库)构建简洁的交互界面,实现实时预测。
  • 可扩展性:你可以替换模型、添加特征可视化(如 matplotlib)或使用更高级的 GUI 框架(如 PyQtKivy)来提升用户体验。

祝你玩得开心,尽情探索机器学习与桌面应用的结合吧!

我们将使用

  • scikit‑learn – 机器学习
  • pandas – CSV 处理
  • tkinter + ttkbootstrap – GUI
  • 可选: tkinterdnd2 – 拖拽支持

您可以在此处克隆完整的仓库:

🔗 Iris-Flower-Classifier-GUI on GitHub

安装依赖

pip install pandas scikit-learn ttkbootstrap
# Optional: for drag & drop
pip install tkinterdnd2

注意: tkinter 在大多数系统上随 Python 预装。

导入库

import os, sys, threading
import pandas as pd
import tkinter as tk
from tkinter import filedialog, messagebox, ttk

import ttkbootstrap as tb
from ttkbootstrap.constants import *

# Optional drag & drop
try:
    from tkinterdnd2 import TkinterDnD, DND_FILES
    DND_ENABLED = True
except ImportError:
    DND_ENABLED = False
    print("Drag & Drop requires tkinterdnd2: pip install tkinterdnd2")

这里我们导入:

  • pandas 用于处理 CSV 文件
  • tkinterttkbootstrap 用于现代 GUI 元素
  • tkinterdnd2(如果需要)用于拖拽 CSV 支持

创建 Iris 机器学习模型

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


class IrisModel:
    def __init__(self):
        data = load_iris()
        self.X = data.data
        self.y = data.target
        self.target_names = data.target_names
        self.scaler = StandardScaler()

        # Scale features
        X_scaled = self.scaler.fit_transform(self.X)

        # Train/test split
        X_train, X_test, y_train, y_test = train_test_split(
            X_scaled, self.y, test_size=0.2, random_state=42
        )

        # Random Forest classifier
        self.clf = RandomForestClassifier(n_estimators=100, random_state=42)
        self.clf.fit(X_train, y_train)

    def predict(self, X):
        X_scaled = self.scaler.transform(X)
        preds = self.clf.predict(X_scaled)
        return [self.target_names[p] for p in preds]

说明

  • 对特征进行标准化,以提升模型性能。
  • 随机森林分类器对初学者友好,并且在此分类任务上表现良好。

CSV 处理工作者

此工作者让 GUI 能够在不冻结界面的情况下处理多个 CSV 文件。

class ClassifierWorker:
    def __init__(self, files, callbacks):
        self.files = files
        self.callbacks = callbacks
        self._running = True
        self.model = IrisModel()

    def stop(self):
        self._running = False

    def run(self):
        total = len(self.files)
        for i, file in enumerate(self.files):
            if not self._running:
                break
            try:
                df = pd.read_csv(file)
                required = {"sepal_length", "sepal_width", "petal_length", "petal_width"}
                if set(df.columns) >= required:
                    X = df[["sepal_length", "sepal_width",
                            "petal_length", "petal_width"]].values
                    preds = self.model.predict(X)
                    if "found" in self.callbacks:
                        self.callbacks["found"](file, preds)
                else:
                    if "found" in self.callbacks:
                        self.callbacks["found"](file, ["Error: Missing required columns"])
            except Exception as e:
                if "found" in self.callbacks:
                    self.callbacks["found"](file, [f"Error: {str(e)}"])
            # Progress update
            if "progress" in self.callbacks:
                self.callbacks["progress"](int((i + 1) / total * 100))
        # Finished
        if "finished" in self.callbacks:
            self.callbacks["finished"]()

说明

  • 逐个读取 CSV 文件。
  • 使用 Iris 模型预测物种。
  • 异步向 GUI 发送进度更新。

构建 GUI

主应用程序类

class IrisClassifierApp:
    def __init__(self):
        if DND_ENABLED:
            self.root = TkinterDnD.Tk()
        else:
            self.root = tb.Window(themename="darkly")
        self.root.title("IrisClassifier v1.1")
        self.root.minsize(1000, 700)

        self.worker_obj = None
        self.file_set = set()
        self.model = IrisModel()

        self._build_ui()
        self._apply_styles()
  • 检查是否支持拖拽。
  • 为现代暗色主题初始化 ttkbootstrap

GUI 布局

def _build_ui(self):
    main = tb.Frame(self.root, padding=10)
    main.pack(fill="both", expand=True)

    tb.Label(
        main,
        text="🌸 Iris Flower Classifier",
        font=("Segoe UI", 20, "bold")
    ).pack(pady=(0, 10))

    # 文件选择行
    row1 = tb.Frame(main)
    row1.pack(fill="x", pady=(0, 6))

    self.path_input = tb.Entry(row1, width=80)
    self.path_input.pack(side="left", fill="x", expand=True, padx=(0, 6))
    self.path_input.insert(0, "Drag & drop CSV files here…")

说明

  • 为文件路径创建输入框。
  • 添加简洁、友好的标签。

按钮与进度条

    browse_btn = tb.Button(
        row1, text="📂 Browse", bootstyle="info", command=self.browse
    )
    browse_btn.pack(side="left", padx=3)

    self.start_btn = tb.Button(
        row1, text="🚀 Classify CSV", bootstyle="success", command=self.start
    )
    self.start_btn.pack(side="left", padx=3)

    self.cancel_btn = tb.Button(
        row1, text="⏹ Cancel", bootstyle="danger", command=self.cancel
    )
    self.cancel_btn.pack(side="left", padx=3)
    self.cancel_btn.config(state="disabled")

    self.progress = tb.Progressbar(
        main, bootstyle="success-striped", maximum=100
    )
    self.progress.pack(fill="x", pady=10)
  • Browse – 打开文件对话框。
  • Classify CSV – 启动工作线程。
  • Cancel – 停止处理。
  • 进度条用于可视化分类进度。

(其余方法——browsestartcancel_apply_styles 等——在原仓库中实现,负责文件选择、线程管理和 UI 更新。)

运行应用程序

if __name__ == "__main__":
    app = IrisClassifierApp()
    app.root.mainloop()

就这样!将包含四个 Iris 测量值的 CSV 文件拖入或浏览,点击 Classify,即可在界面不冻结的情况下看到预测结果。祝您在机器学习和 GUI 开发中玩得开心!

GUI 控件

浏览:选择 CSV 文件

分类 CSV:开始预测

取消:停止预测

进度条:可视化反馈

添加手动输入

manual_frame = tb.Labelframe(main, text="Manual Input", padding=10)
manual_frame.pack(fill="x", pady=(10, 6))

labels = ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"]
self.manual_entries = {}
for i, label in enumerate(labels):
    tb.Label(manual_frame, text=label).grid(row=0, column=i * 2, sticky="w")
    entry = tb.Entry(manual_frame, width=8)
    entry.grid(row=0, column=i * 2 + 1)
    entry.insert(0, "0")
    self.manual_entries[label] = entry

predict_btn = tb.Button(
    manual_frame,
    text="🔮 Predict",
    bootstyle="info",
    command=self.manual_predict,
)
predict_btn.grid(row=0, column=8, padx=10)

self.manual_result = tb.Label(
    manual_frame,
    text="Prediction: ---",
    font=("Segoe UI", 12, "bold"),
)
self.manual_result.grid(row=1, column=0, columnspan=9, pady=(6, 0), sticky="w")

说明

  • 用户可以手动输入花瓣的测量值。
  • 点击 Predict 可立即获取物种预测结果。

Run the App (alternative entry)

if __name__ == "__main__":
    app = IrisClassifierApp()
    app.run()

And that’s it! Your interactive Iris Flower Classifier is ready to run.

Optional Features

  • Drag & Drop CSVs (requires tkinterdnd2)
  • Export results to a text file
  • Beautiful dark/light theme with ttkbootstrap

🎯 摘要

您已经学习了如何:

  • 在 Iris 数据集上训练随机森林分类器
  • 构建一个与之交互的 Python GUI
  • 加载 CSV 文件并进行手动预测
  • 可视化进度和结果

查看完整项目:

🔗 Iris-Flower-Classifier-GUI on GitHub

鸢尾花分类器 GUI

Back to Blog

相关文章

阅读更多 »

Rapg:基于 TUI 的密钥管理器

我们都有这种经历。你加入一个新项目,首先听到的就是:“在 Slack 的置顶消息里查找 .env 文件”。或者你有多个 .env …

技术是赋能者,而非救世主

为什么思考的清晰度比你使用的工具更重要。Technology 常被视为一种魔法开关——只要打开,它就能让一切改善。新的 software,...

踏入 agentic coding

使用 Copilot Agent 的经验 我主要使用 GitHub Copilot 进行 inline edits 和 PR reviews,让我的大脑完成大部分思考。最近我决定 t...