数据分析/机器学习流水线中的 Repository Pattern

发布: (2026年2月3日 GMT+8 01:17)
6 分钟阅读
原文: Dev.to

Source: Dev.to

请提供您希望翻译的具体文本内容,我将为您翻译成简体中文并保持原有的格式、Markdown 语法以及技术术语不变。谢谢!

1. 大多数机器学习项目的核心问题

让我们从一个非常真实的例子开始。

df = pd.read_sql("""
    SELECT customer_id,
           SUM(amount) AS total_amount
    FROM transactions
    WHERE transaction_date >= '2025-01-01'
    GROUP BY customer_id
""", conn)

乍一看,这似乎很高效。
但随着时间推移:

  • SQL 语句越来越多
  • 业务逻辑悄悄混进查询中
  • 特征逻辑被重复实现
  • 没有人知道哪个查询为哪个模型提供数据
  • 测试变得异常痛苦

最终,机器学习流水线会变成一个紧耦合的混乱系统。

👉 仓库模式(Repository Pattern)正是为防止这种情况出现而存在的。

2. 大想法(通俗解释)

“我的机器学习流水线不应该知道数据来自何处——只需要知道它需要什么数据。”

所以不要说:

“使用这个 SQL 查询从 MySQL 获取数据”

你的流水线应该说:

“获取这些日期之间的所有交易”

就这么简单。

3. 心智模型:仓库作为数据翻译器

将仓库视为以下之间的翻译器:

🧠 业务 / 机器学习逻辑🗄️ 物理数据存储 (MySQL)
ML 流水线说: “事务”数据库说: “表、连接、SQL”
仓库两者皆通

仓库隐藏了:

  • SQL
  • 连接处理
  • 数据库特性
  • 性能调优

4. 使用仓库模式的机器学习流水线(概念视图)

MySQL Database
     |
     |  (SQL, connections, credentials)
     v
Repository Layer
     |
     |  (clean Python objects)
     v
Feature Engineering
     |
     v
Model Training

重要规则:

流水线只能依赖仓库的契约,而不能依赖底层存储。

5. 第一步:定义管道关注的内容(领域模型)

在分析工作中,我们不需要笨重的 ORM 实体——只需要有意义的数据结构。

from dataclasses import dataclass
from datetime import date

@dataclass
class Transaction:
    customer_id: int
    amount: float
    transaction_date: date
    transaction_type: str

为什么这很重要

  • 没有 SQL
  • 没有 MySQL
  • 没有 Pandas
  • 纯 Python

这就是 领域语言:分析师和机器学习工程师都能理解的表达方式。

6. 步骤 2:定义仓库契约(承诺)

现在我们要问:机器学习流水线需要什么数据?
不是如何获取,而是什么

from abc import ABC, abstractmethod
from typing import List
from datetime import date

class TransactionRepository(ABC):
    @abstractmethod
    def get_transactions(
        self,
        start_date: date,
        end_date: date
    ) -> List[Transaction]:
        pass

关键理念

这是一种承诺:任何数据源、任何数据库、任何存储引擎——只要它满足此契约即可。

7. 为什么这很强大

此时:

  • ML 管道依赖于 接口
  • 不依赖 MySQL
  • 不依赖 PyMySQL

这为您提供:

  • 可测试性
  • 灵活性
  • 清晰的设计

Source:

8. 第 3 步:使用 PyMySQL 实现仓库

现在——仅此时——我们才接触 MySQL。

连接助手

import pymysql

def get_connection():
    return pymysql.connect(
        host="localhost",
        user="analytics_user",
        password="analytics_pwd",
        database="analytics_db",
        cursorclass=pymysql.cursors.DictCursor
    )

基于 MySQL 的仓库

class MySQLTransactionRepository(TransactionRepository):
    def get_transactions(self, start_date, end_date):
        query = """
            SELECT customer_id,
                   amount,
                   transaction_date,
                   transaction_type
            FROM transactions
            WHERE transaction_date BETWEEN %s AND %s
        """

        conn = get_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(query, (start_date, end_date))
                rows = cursor.fetchall()

            return [
                Transaction(
                    customer_id=row["customer_id"],
                    amount=float(row["amount"]),
                    transaction_date=row["transaction_date"],
                    transaction_type=row["transaction_type"]
                )
                for row in rows
            ]
        finally:
            conn.close()

刚才发生了什么?

  • SQL 被隔离
  • 连接生命周期受到控制
  • 原始行被转换为 领域对象

系统中的其他部分保持干净。

9. 第4步:特征工程(纯数据逻辑)

此层不知道:

  • 数据来源
  • 查询方式
  • 是否为 MySQL、CSV 或 API
import pandas as pd

class TransactionFeatureEngineer:
    def build_customer_features(self, transactions):
        df = pd.DataFrame([t.__dict__ for t in transactions])

        features = (
            df.groupby("customer_id")
              .agg(
                  total_amount=("amount", "sum"),
                  avg_amount=("amount", "mean"),
                  txn_count=("amount", "count")
              )
              .reset_index()
        )
        return features

为什么这样干净

  • 确定性
  • 易于单元测试
  • 无副作用
  • 可在多个模型间复用

10. Step 5: Model Training Layer

Again — no database awareness.

from sklearn.ensemble import RandomForestClassifier

class ModelTrainingService:
    def train(self, X, y):
        model = RandomForestClassifier(
            n_estimators=100,
            random_state=42
        )
        model.fit(X, y)
        return model

The model only cares about features, not data sources.

11. Step 6: 编排管道

This is the only place everything comes together.

from datetime import date

repo = MySQLTransactionRepository()
feature_engineer = TransactionFeatureEngineer()
trainer = ModelTrainingService()

# Fetch
transactions = repo.get_transactions(
    date(2025, 1, 1),
    date(2025, 12, 31)
)

# Features
features_df = feature_engineer.build_customer_features(transactions)

# Example target
features_df["target"] = (
    features_df["total_amount"] > 100000
).astype(int)

X = features_df[["total_amount", "avg_amount", "txn_count"]]
y = features_df["target"]

# Train
model = traine

注意: 最后一行 (model = traine) 故意保持原样,以保留原始内容。

r.train(X, y)

这段代码读起来像故事,而不是管道。

12. 真正的超级能力:在没有 MySQL 的情况下进行测试

Now comes the magic: an in‑memory repository.

class InMemoryTransactionRepository(TransactionRepository):
    def __init__(self, transactions):
        self.transactions = transactions

    def get_transactions(self, start_date, end_date):
        return [
            t for t in self.transactions
            if start_date
        ]

Repositories answer “WHAT data?”

将这些问题分开——你的数据管道就能保持理智。

Back to Blog

相关文章

阅读更多 »