数据分析/机器学习流水线中的 Repository Pattern
Source: Dev.to
请提供您希望翻译的具体文本内容,我将为您翻译成简体中文并保持原有的格式、Markdown 语法以及技术术语不变。谢谢!
1. 大多数机器学习项目的核心问题
让我们从一个非常真实的例子开始。
df = pd.read_sql("""
SELECT customer_id,
SUM(amount) AS total_amount
FROM transactions
WHERE transaction_date >= '2025-01-01'
GROUP BY customer_id
""", conn)
乍一看,这似乎很高效。
但随着时间推移:
- SQL 语句越来越多
- 业务逻辑悄悄混进查询中
- 特征逻辑被重复实现
- 没有人知道哪个查询为哪个模型提供数据
- 测试变得异常痛苦
最终,机器学习流水线会变成一个紧耦合的混乱系统。
👉 仓库模式(Repository Pattern)正是为防止这种情况出现而存在的。
2. 大想法(通俗解释)
“我的机器学习流水线不应该知道数据来自何处——只需要知道它需要什么数据。”
所以不要说:
“使用这个 SQL 查询从 MySQL 获取数据”
你的流水线应该说:
“获取这些日期之间的所有交易”
就这么简单。
3. 心智模型:仓库作为数据翻译器
将仓库视为以下之间的翻译器:
| 🧠 业务 / 机器学习逻辑 | 🗄️ 物理数据存储 (MySQL) |
|---|---|
| ML 流水线说: “事务” | 数据库说: “表、连接、SQL” |
| 仓库两者皆通 |
仓库隐藏了:
- SQL
- 连接处理
- 数据库特性
- 性能调优
4. 使用仓库模式的机器学习流水线(概念视图)
MySQL Database
|
| (SQL, connections, credentials)
v
Repository Layer
|
| (clean Python objects)
v
Feature Engineering
|
v
Model Training
重要规则:
流水线只能依赖仓库的契约,而不能依赖底层存储。
5. 第一步:定义管道关注的内容(领域模型)
在分析工作中,我们不需要笨重的 ORM 实体——只需要有意义的数据结构。
from dataclasses import dataclass
from datetime import date
@dataclass
class Transaction:
customer_id: int
amount: float
transaction_date: date
transaction_type: str
为什么这很重要
- 没有 SQL
- 没有 MySQL
- 没有 Pandas
- 纯 Python
这就是 领域语言:分析师和机器学习工程师都能理解的表达方式。
6. 步骤 2:定义仓库契约(承诺)
现在我们要问:机器学习流水线需要什么数据?
不是如何获取,而是什么。
from abc import ABC, abstractmethod
from typing import List
from datetime import date
class TransactionRepository(ABC):
@abstractmethod
def get_transactions(
self,
start_date: date,
end_date: date
) -> List[Transaction]:
pass
关键理念
这是一种承诺:任何数据源、任何数据库、任何存储引擎——只要它满足此契约即可。
7. 为什么这很强大
此时:
- ML 管道依赖于 接口
- 它 不依赖 MySQL
- 它 不依赖 PyMySQL
这为您提供:
- 可测试性
- 灵活性
- 清晰的设计
Source: …
8. 第 3 步:使用 PyMySQL 实现仓库
现在——仅此时——我们才接触 MySQL。
连接助手
import pymysql
def get_connection():
return pymysql.connect(
host="localhost",
user="analytics_user",
password="analytics_pwd",
database="analytics_db",
cursorclass=pymysql.cursors.DictCursor
)
基于 MySQL 的仓库
class MySQLTransactionRepository(TransactionRepository):
def get_transactions(self, start_date, end_date):
query = """
SELECT customer_id,
amount,
transaction_date,
transaction_type
FROM transactions
WHERE transaction_date BETWEEN %s AND %s
"""
conn = get_connection()
try:
with conn.cursor() as cursor:
cursor.execute(query, (start_date, end_date))
rows = cursor.fetchall()
return [
Transaction(
customer_id=row["customer_id"],
amount=float(row["amount"]),
transaction_date=row["transaction_date"],
transaction_type=row["transaction_type"]
)
for row in rows
]
finally:
conn.close()
刚才发生了什么?
- SQL 被隔离
- 连接生命周期受到控制
- 原始行被转换为 领域对象
系统中的其他部分保持干净。
9. 第4步:特征工程(纯数据逻辑)
此层不知道:
- 数据来源
- 查询方式
- 是否为 MySQL、CSV 或 API
import pandas as pd
class TransactionFeatureEngineer:
def build_customer_features(self, transactions):
df = pd.DataFrame([t.__dict__ for t in transactions])
features = (
df.groupby("customer_id")
.agg(
total_amount=("amount", "sum"),
avg_amount=("amount", "mean"),
txn_count=("amount", "count")
)
.reset_index()
)
return features
为什么这样干净
- 确定性
- 易于单元测试
- 无副作用
- 可在多个模型间复用
10. Step 5: Model Training Layer
Again — no database awareness.
from sklearn.ensemble import RandomForestClassifier
class ModelTrainingService:
def train(self, X, y):
model = RandomForestClassifier(
n_estimators=100,
random_state=42
)
model.fit(X, y)
return model
The model only cares about features, not data sources.
11. Step 6: 编排管道
This is the only place everything comes together.
from datetime import date
repo = MySQLTransactionRepository()
feature_engineer = TransactionFeatureEngineer()
trainer = ModelTrainingService()
# Fetch
transactions = repo.get_transactions(
date(2025, 1, 1),
date(2025, 12, 31)
)
# Features
features_df = feature_engineer.build_customer_features(transactions)
# Example target
features_df["target"] = (
features_df["total_amount"] > 100000
).astype(int)
X = features_df[["total_amount", "avg_amount", "txn_count"]]
y = features_df["target"]
# Train
model = traine
注意: 最后一行 (
model = traine) 故意保持原样,以保留原始内容。
r.train(X, y)
这段代码读起来像故事,而不是管道。
12. 真正的超级能力:在没有 MySQL 的情况下进行测试
Now comes the magic: an in‑memory repository.
class InMemoryTransactionRepository(TransactionRepository):
def __init__(self, transactions):
self.transactions = transactions
def get_transactions(self, start_date, end_date):
return [
t for t in self.transactions
if start_date
]
Repositories answer “WHAT data?”
将这些问题分开——你的数据管道就能保持理智。