本文由claude code生成。作为备忘录和分享,欢迎指正和补充。

一、Ent简介

Ent是一个简单但功能强大的实体框架,用于Go语言,使构建和维护具有大型数据模型的应用程序变得容易。 Ent是由Meta(Facebook)开源团队构建的ORM框架,提供了一个API,用于将任何数据库模式建模为Go对象。

核心特性

Ent的主要特性包括:

  • Schema As Code:将任何数据库模式建模为Go对象
  • 轻松遍历任何图结构:运行查询、聚合并轻松遍历任何图结构
  • 静态类型和显式API:100%静态类型和显式API,通过代码生成实现
  • 多存储驱动支持:支持MySQL、PostgreSQL、SQLite和Gremlin
  • 易于扩展:使用Go模板简单扩展和自定义

Ent是一个相当新的ORM,使用代码优先的方法,在Go代码中定义模式。Ent之所以流行,是因为它能够优雅地处理复杂的数据模型和关系。

二、快速开始

2.1 安装Ent

# 初始化Go模块
go mod init myapp

# 安装ent CLI工具
go install entgo.io/ent/cmd/ent@latest

# 或者使用go get
go get entgo.io/ent/cmd/ent

2.2 创建第一个Schema

# 创建User实体
ent new User

这会在 ent/schema/ 目录下生成一个基础的schema文件:

package schema

import (
    "entgo.io/ent"
    "entgo.io/ent/schema/field"
)

// User holds the schema definition for the User entity.
type User struct {
    ent.Schema
}

// Fields of the User.
func (User) Fields() []ent.Field {
    return []ent.Field{
        field.String("name").
            NotEmpty(),
        field.Int("age").
            Positive(),
        field.String("email").
            Unique(),
        field.Time("created_at").
            Default(time.Now),
    }
}

// Edges of the User.
func (User) Edges() []ent.Edge {
    return nil
}

2.3 生成代码

go generate ./ent

2.4 连接数据库并使用

package main

import (
    "context"
    "log"

    "myapp/ent"

    _ "github.com/lib/pq"           // PostgreSQL
    // _ "github.com/go-sql-driver/mysql"  // MySQL
    // _ "github.com/mattn/go-sqlite3"     // SQLite
)

func main() {
    // 连接PostgreSQL
    client, err := ent.Open("postgres",
        "host=localhost port=5432 user=postgres dbname=myapp password=password sslmode=disable")
    if err != nil {
        log.Fatalf("failed opening connection to postgres: %v", err)
    }
    defer client.Close()

    // 运行自动迁移
    if err := client.Schema.Create(context.Background()); err != nil {
        log.Fatalf("failed creating schema resources: %v", err)
    }

    ctx := context.Background()

    // 创建用户
    user, err := client.User.
        Create().
        SetName("张三").
        SetAge(30).
        SetEmail("zhangsan@example.com").
        Save(ctx)
    if err != nil {
        log.Fatalf("failed creating user: %v", err)
    }
    log.Println("user was created: ", user)
}

2.5 使用SQLite3(非CGO驱动)

📋 使用场景

为什么使用非CGO的SQLite驱动?

  • 跨平台编译 - 不需要CGO,可以轻松交叉编译
  • 部署简单 - 不依赖SQLite C库
  • 开发便捷 - 不需要配置CGO环境
  • 完全兼容 - 功能与CGO版本完全一致

适用场景:

  • 嵌入式应用或桌面应用
  • 开发和测试环境
  • 轻量级数据存储
  • 需要跨平台编译的项目

安装和配置

# 安装非CGO的SQLite驱动
go get github.com/ncruces/go-sqlite3

基本使用

package main

import (
    "context"
    "log"

    "myapp/ent"

    _ "github.com/ncruces/go-sqlite3/driver"
    _ "github.com/ncruces/go-sqlite3/embed"
)

func main() {
    // 使用文件模式连接SQLite
    // cache=shared: 允许多个连接共享缓存
    // _pragma: 设置SQLite参数
    dsn := "file:./data.db?cache=shared" +
        "&_pragma=foreign_keys(1)" +      // 启用外键约束
        "&_pragma=journal_mode(WAL)" +    // 使用WAL模式提升并发性能
        "&_pragma=synchronous(NORMAL)" +  // 平衡性能和安全性
        "&_pragma=busy_timeout(10000)" +  // 10秒超时
        "&_pragma=cache_size(-64000)"     // 64MB缓存
    
    client, err := ent.Open("sqlite3", dsn)
    if err != nil {
        log.Fatalf("failed opening connection to sqlite: %v", err)
    }
    defer client.Close()

    // 运行自动迁移
    if err := client.Schema.Create(context.Background()); err != nil {
        log.Fatalf("failed creating schema resources: %v", err)
    }

    ctx := context.Background()

    // 创建用户
    user, err := client.User.
        Create().
        SetName("张三").
        SetAge(30).
        SetEmail("zhangsan@example.com").
        Save(ctx)
    if err != nil {
        log.Fatalf("failed creating user: %v", err)
    }
    log.Println("user was created: ", user)
}

内存模式

// 使用内存数据库(测试场景)
func NewTestClient() *ent.Client {
    client, err := ent.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
    if err != nil {
        log.Fatalf("failed opening in-memory db: %v", err)
    }
    
    // 运行迁移
    if err := client.Schema.Create(context.Background()); err != nil {
        log.Fatalf("failed creating schema resources: %v", err)
    }
    
    return client
}

SQLite优化配置

package database

import (
    "database/sql"
    "time"

    "entgo.io/ent/dialect"
    entsql "entgo.io/ent/dialect/sql"
    _ "github.com/ncruces/go-sqlite3/driver"
    _ "github.com/ncruces/go-sqlite3/embed"
    
    "myapp/ent"
)

func NewSQLiteClient(dbPath string) (*ent.Client, error) {
    // 构建DSN
    dsn := "file:" + dbPath + 
        "?cache=shared" +
        "&_pragma=foreign_keys(1)" +
        "&_pragma=journal_mode(WAL)" +
        "&_pragma=synchronous(NORMAL)" +
        "&_pragma=busy_timeout(10000)" +
        "&_pragma=cache_size(-64000)" +
        "&_pragma=temp_store(MEMORY)" +
        "&_pragma=mmap_size(268435456)"  // 256MB mmap
    
    // 打开数据库
    db, err := sql.Open("sqlite3", dsn)
    if err != nil {
        return nil, err
    }
    
    // 配置连接池(SQLite建议单连接)
    db.SetMaxOpenConns(1)
    db.SetMaxIdleConns(1)
    db.SetConnMaxLifetime(time.Hour)
    
    // 创建Ent客户端
    drv := entsql.OpenDB(dialect.SQLite, db)
    client := ent.NewClient(ent.Driver(drv))
    
    return client, nil
}

使用lib-x/entsqlite

// 另一个流行的非CGO SQLite驱动
import _ "github.com/lib-x/entsqlite"

func main() {
    dsn := "file:./data.db?cache=shared" +
        "&_pragma=foreign_keys(1)" +
        "&_pragma=journal_mode(WAL)" +
        "&_pragma=synchronous(NORMAL)" +
        "&_pragma=busy_timeout(10000)"
    
    client, err := ent.Open("sqlite3", dsn)
    if err != nil {
        log.Fatalf("failed opening connection: %v", err)
    }
    defer client.Close()
    
    // 使用client...
}

SQLite特性说明

WAL模式(Write-Ahead Logging):

  • 提升并发读写性能
  • 允许读操作和写操作并行
  • 生产环境推荐使用

同步模式:

  • FULL: 最安全,性能最慢
  • NORMAL: 平衡(推荐)
  • OFF: 最快,但可能丢数据

忙等待超时:

  • 当数据库被锁定时等待的时间
  • 10000毫秒(10秒)对大多数应用合适

缓存大小:

  • 负数表示KB,例如 -64000 = 64MB
  • 增加缓存可以提升性能

三、Schema定义详解

3.1 字段类型(Fields)

Ent支持多种字段类型,每种类型都有丰富的配置选项:

import (
    "time"
    "regexp"

    "entgo.io/ent"
    "entgo.io/ent/schema/field"
)

type User struct {
    ent.Schema
}

func (User) Fields() []ent.Field {
    return []ent.Field{
        // 基础类型
        field.Int("id"),
        field.Int64("big_id"),
        field.Float("score"),
        field.Bool("active"),
        field.String("name"),
        field.Bytes("avatar"),
        field.Time("created_at"),
        field.UUID("uuid", uuid.UUID{}),
        field.JSON("metadata", map[string]interface{}{}),
        
        // 枚举类型
        field.Enum("status").
            Values("pending", "active", "inactive", "deleted"),
        
        // 带验证的字段
        field.String("email").
            Match(regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)).
            Unique(),
        
        // 可选字段(允许NULL)
        field.String("nickname").
            Optional().
            Nillable(),
        
        // 带默认值
        field.Time("created_at").
            Default(time.Now).
            Immutable(), // 创建后不可修改
        
        // 自定义校验
        field.Int("age").
            Positive().
            Max(150),
        
        // 敏感字段(不会在日志中显示)
        field.String("password").
            Sensitive(),
        
        // 带注释
        field.String("description").
            Comment("用户描述信息"),
        
        // 存储键名映射
        field.String("phone_number").
            StorageKey("phone"),
        
        // 结构体嵌入
        field.JSON("address", &Address{}).
            Optional(),
    }
}

type Address struct {
    Street  string `json:"street"`
    City    string `json:"city"`
    Country string `json:"country"`
}

3.2 关系定义(Edges)

Edges代表Ent中的关系(一对一、一对多、多对多)。

package schema

import (
    "entgo.io/ent"
    "entgo.io/ent/schema/edge"
)

// User schema
type User struct {
    ent.Schema
}

func (User) Edges() []ent.Edge {
    return []ent.Edge{
        // 一对多:一个用户有多篇文章
        edge.To("posts", Post.Type),
        
        // 一对一:一个用户有一个档案
        edge.To("profile", Profile.Type).
            Unique(),
        
        // 多对多:用户可以属于多个群组
        edge.To("groups", Group.Type),
        
        // 自引用:用户的朋友关系
        edge.To("friends", User.Type),
        
        // 带边缘属性的多对多关系
        edge.To("following", User.Type).
            Through("user_following", UserFollowing.Type),
    }
}

// Post schema
type Post struct {
    ent.Schema
}

func (Post) Edges() []ent.Edge {
    return []ent.Edge{
        // 反向边:文章属于一个用户
        edge.From("author", User.Type).
            Ref("posts").
            Unique().
            Required(),
        
        // 多对多:文章可以有多个标签
        edge.To("tags", Tag.Type),
    }
}

// Profile schema (一对一示例)
type Profile struct {
    ent.Schema
}

func (Profile) Edges() []ent.Edge {
    return []ent.Edge{
        edge.From("user", User.Type).
            Ref("profile").
            Unique().
            Required(),
    }
}

// Group schema
type Group struct {
    ent.Schema
}

func (Group) Edges() []ent.Edge {
    return []ent.Edge{
        edge.From("members", User.Type).
            Ref("groups"),
    }
}

// UserFollowing - 用于存储关注关系的边缘属性
type UserFollowing struct {
    ent.Schema
}

func (UserFollowing) Fields() []ent.Field {
    return []ent.Field{
        field.Time("followed_at").
            Default(time.Now),
    }
}

func (UserFollowing) Edges() []ent.Edge {
    return []ent.Edge{
        edge.To("follower", User.Type).
            Unique().
            Required(),
        edge.To("following", User.Type).
            Unique().
            Required(),
    }
}

3.3 索引定义(Indexes)

package schema

import (
    "entgo.io/ent"
    "entgo.io/ent/schema/index"
)

type User struct {
    ent.Schema
}

func (User) Indexes() []ent.Index {
    return []ent.Index{
        // 单字段索引
        index.Fields("email"),
        
        // 复合索引
        index.Fields("name", "age"),
        
        // 唯一索引
        index.Fields("username").
            Unique(),
        
        // 复合唯一索引
        index.Fields("tenant_id", "email").
            Unique(),
        
        // 边和字段的联合索引
        index.Fields("name").
            Edges("parent"),
    }
}

3.4 Mixin(混入)

Mixin是一个接口,允许创建可重用的schema片段,可以注入到另一个schema中。Mixin可以是一组字段、边或钩子。

package schema

import (
    "time"

    "entgo.io/ent"
    "entgo.io/ent/schema/field"
    "entgo.io/ent/schema/mixin"
)

// TimeMixin 提供创建时间和更新时间字段
type TimeMixin struct {
    mixin.Schema
}

func (TimeMixin) Fields() []ent.Field {
    return []ent.Field{
        field.Time("created_at").
            Immutable().
            Default(time.Now),
        field.Time("updated_at").
            Default(time.Now).
            UpdateDefault(time.Now),
    }
}

// SoftDeleteMixin 提供软删除功能
type SoftDeleteMixin struct {
    mixin.Schema
}

func (SoftDeleteMixin) Fields() []ent.Field {
    return []ent.Field{
        field.Time("deleted_at").
            Optional().
            Nillable(),
    }
}

// TenantMixin 提供多租户支持
type TenantMixin struct {
    mixin.Schema
}

func (TenantMixin) Fields() []ent.Field {
    return []ent.Field{
        field.Int("tenant_id").
            Immutable(),
    }
}

// User使用Mixin
type User struct {
    ent.Schema
}

func (User) Mixin() []ent.Mixin {
    return []ent.Mixin{
        TimeMixin{},
        SoftDeleteMixin{},
        TenantMixin{},
    }
}

func (User) Fields() []ent.Field {
    return []ent.Field{
        field.String("name"),
        field.String("email").Unique(),
    }
}

四、CRUD操作详解

4.1 创建操作(Create)

func CreateExamples(ctx context.Context, client *ent.Client) {
    // 创建单个实体
    user, err := client.User.
        Create().
        SetName("张三").
        SetEmail("zhangsan@example.com").
        SetAge(25).
        Save(ctx)
    
    // 创建并返回ID
    id, err := client.User.
        Create().
        SetName("李四").
        SetEmail("lisi@example.com").
        SaveX(ctx).ID // SaveX会panic如果出错
    
    // 批量创建
    users, err := client.User.
        CreateBulk(
            client.User.Create().SetName("用户1").SetEmail("user1@example.com"),
            client.User.Create().SetName("用户2").SetEmail("user2@example.com"),
            client.User.Create().SetName("用户3").SetEmail("user3@example.com"),
        ).
        Save(ctx)
    
    // 创建并关联关系
    post, err := client.Post.
        Create().
        SetTitle("我的第一篇文章").
        SetContent("内容...").
        SetAuthor(user). // 关联用户
        AddTags(tag1, tag2). // 添加多个标签
        Save(ctx)
    
    // Upsert(存在则更新,不存在则创建)
    err = client.User.
        Create().
        SetEmail("zhangsan@example.com").
        SetName("张三").
        OnConflict(
            sql.ConflictColumns("email"),
        ).
        UpdateNewValues().
        Exec(ctx)
    
    // Upsert批量
    err = client.User.
        CreateBulk(builders...).
        OnConflict(
            sql.ConflictColumns("email"),
        ).
        UpdateNewValues().
        Exec(ctx)
}

4.2 查询操作(Query)

func QueryExamples(ctx context.Context, client *ent.Client) {
    // 查询单个实体 - 通过ID
    user, err := client.User.Get(ctx, 1)
    
    // 查询单个实体 - 通过条件
    user, err = client.User.
        Query().
        Where(user.Email("zhangsan@example.com")).
        Only(ctx) // 只返回一个,多于一个会报错
    
    // First - 返回第一个
    user, err = client.User.
        Query().
        Where(user.AgeGT(18)).
        First(ctx)
    
    // 查询多个实体
    users, err := client.User.
        Query().
        Where(
            user.Or(
                user.AgeGT(18),
                user.NameContains("张"),
            ),
        ).
        All(ctx)
    
    // 条件组合 - 复杂查询
    users, err = client.User.
        Query().
        Where(
            user.And(
                user.AgeGTE(18),
                user.AgeLTE(60),
                user.Or(
                    user.StatusEQ("active"),
                    user.StatusEQ("pending"),
                ),
                user.Not(
                    user.NameHasPrefix("test_"),
                ),
            ),
        ).
        All(ctx)
    
    // 排序
    users, err = client.User.
        Query().
        Order(ent.Desc(user.FieldCreatedAt)).
        Order(ent.Asc(user.FieldName)).
        All(ctx)
    
    // 分页
    users, err = client.User.
        Query().
        Offset(0).
        Limit(10).
        All(ctx)
    
    // 选择特定字段
    names, err := client.User.
        Query().
        Select(user.FieldName, user.FieldEmail).
        Strings(ctx)
    
    // 去重
    names, err = client.User.
        Query().
        Unique(true).
        Select(user.FieldName).
        Strings(ctx)
    
    // 计数
    count, err := client.User.
        Query().
        Where(user.StatusEQ("active")).
        Count(ctx)
    
    // 判断是否存在
    exists, err := client.User.
        Query().
        Where(user.Email("test@example.com")).
        Exist(ctx)
    
    // 聚合查询
    var result []struct {
        Status string
        Count  int
        SumAge int
        AvgAge float64
        MaxAge int
        MinAge int
    }
    err = client.User.
        Query().
        GroupBy(user.FieldStatus).
        Aggregate(
            ent.Count(),
            ent.Sum(user.FieldAge),
            ent.Mean(user.FieldAge),
            ent.Max(user.FieldAge),
            ent.Min(user.FieldAge),
        ).
        Scan(ctx, &result)
    
    // Having子句
    err = client.User.
        Query().
        GroupBy(user.FieldStatus).
        Aggregate(ent.Count()).
        Having(
            sql.GT(sql.Count("*"), 10),
        ).
        Scan(ctx, &result)
}

4.3 关系查询(Edge Query)

func EdgeQueryExamples(ctx context.Context, client *ent.Client) {
    // 查询用户的所有文章
    user, _ := client.User.Get(ctx, 1)
    posts, err := user.QueryPosts().All(ctx)
    
    // 带条件的关系查询
    posts, err = user.
        QueryPosts().
        Where(post.StatusEQ("published")).
        Order(ent.Desc(post.FieldCreatedAt)).
        Limit(10).
        All(ctx)
    
    // 预加载(Eager Loading)- 避免N+1问题
    users, err := client.User.
        Query().
        WithPosts(). // 预加载文章
        WithProfile(). // 预加载档案
        WithGroups(func(q *ent.GroupQuery) {
            // 可以对预加载的关系添加条件
            q.Where(group.ActiveEQ(true))
            q.Order(ent.Asc(group.FieldName))
        }).
        All(ctx)
    
    // 访问预加载的数据
    for _, u := range users {
        fmt.Println("User:", u.Name)
        for _, p := range u.Edges.Posts {
            fmt.Println("  Post:", p.Title)
        }
        if u.Edges.Profile != nil {
            fmt.Println("  Bio:", u.Edges.Profile.Bio)
        }
    }
    
    // 深层嵌套预加载
    users, err = client.User.
        Query().
        WithPosts(func(q *ent.PostQuery) {
            q.WithTags()
            q.WithComments(func(cq *ent.CommentQuery) {
                cq.WithAuthor()
            })
        }).
        All(ctx)
    
    // 通过关系查询(反向查询)
    // 查询发表了某篇文章的用户
    author, err := client.Post.
        Query().
        Where(post.IDEQ(1)).
        QueryAuthor().
        Only(ctx)
    
    // 链式关系查询
    // 查询用户朋友的所有文章
    posts, err = client.User.
        Query().
        Where(user.IDEQ(1)).
        QueryFriends().
        QueryPosts().
        Where(post.StatusEQ("published")).
        All(ctx)
    
    // Named Edge(命名边加载)
    users, err = client.User.
        Query().
        WithNamedPosts("recent_posts", func(q *ent.PostQuery) {
            q.Order(ent.Desc(post.FieldCreatedAt)).Limit(5)
        }).
        WithNamedPosts("popular_posts", func(q *ent.PostQuery) {
            q.Order(ent.Desc(post.FieldViewCount)).Limit(5)
        }).
        All(ctx)
    
    // 访问命名边
    for _, u := range users {
        recentPosts, _ := u.NamedPosts("recent_posts")
        popularPosts, _ := u.NamedPosts("popular_posts")
    }
}

4.4 更新操作(Update)

func UpdateExamples(ctx context.Context, client *ent.Client) {
    // 更新单个实体(通过ID)
    user, err := client.User.
        UpdateOneID(1).
        SetName("新名字").
        SetAge(26).
        Save(ctx)
    
    // 更新单个实体(通过实体对象)
    user, _ = client.User.Get(ctx, 1)
    user, err = user.
        Update().
        SetName("新名字").
        AddAge(1). // 年龄+1
        Save(ctx)
    
    // 条件更新(批量)
    affected, err := client.User.
        Update().
        Where(user.StatusEQ("inactive")).
        SetStatus("deleted").
        Save(ctx)
    fmt.Printf("更新了 %d 条记录\n", affected)
    
    // 清除可选字段
    user, err = client.User.
        UpdateOneID(1).
        ClearNickname(). // 设置为NULL
        ClearDeletedAt().
        Save(ctx)
    
    // 更新关系
    user, err = client.User.
        UpdateOneID(1).
        AddPostIDs(1, 2, 3). // 添加文章关联
        RemovePostIDs(4, 5). // 移除文章关联
        ClearPosts().        // 清除所有文章关联
        SetProfileID(1).     // 设置档案
        Save(ctx)
    
    // 添加/移除多对多关系
    user, err = client.User.
        UpdateOneID(1).
        AddGroups(group1, group2).
        RemoveGroups(group3).
        Save(ctx)
    
    // 数值字段操作
    post, err := client.Post.
        UpdateOneID(1).
        AddViewCount(1).      // 浏览量+1
        AddLikeCount(1).      // 点赞+1
        Save(ctx)
    
    // 使用Modifier进行更复杂的更新
    err = client.User.
        Update().
        Where(user.IDEQ(1)).
        Modify(func(u *sql.UpdateBuilder) {
            u.Set("login_count", sql.Expr("login_count + 1"))
            u.Set("last_login_at", sql.Expr("NOW()"))
        }).
        Exec(ctx)
}

4.5 删除操作(Delete)

func DeleteExamples(ctx context.Context, client *ent.Client) {
    // 删除单个实体(通过ID)
    err := client.User.DeleteOneID(1).Exec(ctx)
    
    // 删除单个实体(通过实体对象)
    user, _ := client.User.Get(ctx, 1)
    err = client.User.DeleteOne(user).Exec(ctx)
    
    // 条件删除(批量)
    affected, err := client.User.
        Delete().
        Where(
            user.And(
                user.StatusEQ("deleted"),
                user.DeletedAtLT(time.Now().AddDate(0, -6, 0)), // 6个月前删除的
            ),
        ).
        Exec(ctx)
    fmt.Printf("删除了 %d 条记录\n", affected)
    
    // 清空表(谨慎使用)
    affected, err = client.User.Delete().Exec(ctx)
}

五、事务处理

5.1 基本事务

func TransactionExample(ctx context.Context, client *ent.Client) error {
    // 开始事务
    tx, err := client.Tx(ctx)
    if err != nil {
        return err
    }
    
    // 创建用户
    user, err := tx.User.
        Create().
        SetName("张三").
        SetEmail("zhangsan@example.com").
        Save(ctx)
    if err != nil {
        return rollback(tx, err)
    }
    
    // 创建档案
    _, err = tx.Profile.
        Create().
        SetBio("这是个人简介").
        SetUser(user).
        Save(ctx)
    if err != nil {
        return rollback(tx, err)
    }
    
    // 提交事务
    return tx.Commit()
}

// 回滚辅助函数
func rollback(tx *ent.Tx, err error) error {
    if rerr := tx.Rollback(); rerr != nil {
        err = fmt.Errorf("%w: %v", err, rerr)
    }
    return err
}

5.2 使用WithTx辅助函数

// WithTx 封装事务处理逻辑
func WithTx(ctx context.Context, client *ent.Client, fn func(tx *ent.Tx) error) error {
    tx, err := client.Tx(ctx)
    if err != nil {
        return err
    }
    
    defer func() {
        if v := recover(); v != nil {
            tx.Rollback()
            panic(v)
        }
    }()
    
    if err := fn(tx); err != nil {
        if rerr := tx.Rollback(); rerr != nil {
            err = fmt.Errorf("%w: rolling back transaction: %v", err, rerr)
        }
        return err
    }
    
    return tx.Commit()
}

// 使用示例
func CreateUserWithProfile(ctx context.Context, client *ent.Client) error {
    return WithTx(ctx, client, func(tx *ent.Tx) error {
        user, err := tx.User.
            Create().
            SetName("张三").
            SetEmail("zhangsan@example.com").
            Save(ctx)
        if err != nil {
            return err
        }
        
        _, err = tx.Profile.
            Create().
            SetBio("个人简介").
            SetUser(user).
            Save(ctx)
        return err
    })
}

5.3 嵌套事务和保存点

func NestedTransactionExample(ctx context.Context, client *ent.Client) error {
    tx, err := client.Tx(ctx)
    if err != nil {
        return err
    }
    
    // 创建第一个用户
    user1, err := tx.User.Create().SetName("用户1").SetEmail("user1@example.com").Save(ctx)
    if err != nil {
        return rollback(tx, err)
    }
    
    // 创建保存点
    savepoint := tx.Savepoint(ctx, "sp1")
    
    // 尝试创建第二个用户
    _, err = tx.User.Create().SetName("用户2").SetEmail("user1@example.com").Save(ctx) // 邮箱重复
    if err != nil {
        // 回滚到保存点,而不是整个事务
        if rerr := savepoint.Rollback(); rerr != nil {
            return rollback(tx, fmt.Errorf("%w: %v", err, rerr))
        }
        // 继续执行其他操作...
    }
    
    return tx.Commit()
}

六、Hooks(钩子)

Hooks允许你在实体创建、更新或删除之前或之后执行自定义逻辑。

6.1 Schema级别的Hooks

package schema

import (
    "context"
    "fmt"
    "time"

    "entgo.io/ent"
    "entgo.io/ent/schema/field"
    
    gen "myapp/ent"
    "myapp/ent/hook"
)

type User struct {
    ent.Schema
}

func (User) Hooks() []ent.Hook {
    return []ent.Hook{
        // 创建前的钩子
        hook.On(
            func(next ent.Mutator) ent.Mutator {
                return hook.UserFunc(func(ctx context.Context, m *gen.UserMutation) (ent.Value, error) {
                    // 设置默认值
                    if _, exists := m.Status(); !exists {
                        m.SetStatus("pending")
                    }
                    
                    // 记录日志
                    name, _ := m.Name()
                    fmt.Printf("Creating user: %s\n", name)
                    
                    return next.Mutate(ctx, m)
                })
            },
            ent.OpCreate,
        ),
        
        // 更新前的钩子
        hook.On(
            func(next ent.Mutator) ent.Mutator {
                return hook.UserFunc(func(ctx context.Context, m *gen.UserMutation) (ent.Value, error) {
                    // 自动更新updated_at字段
                    m.SetUpdatedAt(time.Now())
                    
                    return next.Mutate(ctx, m)
                })
            },
            ent.OpUpdate|ent.OpUpdateOne,
        ),
        
        // 删除前的钩子
        hook.On(
            func(next ent.Mutator) ent.Mutator {
                return hook.UserFunc(func(ctx context.Context, m *gen.UserMutation) (ent.Value, error) {
                    // 获取要删除的用户ID
                    id, exists := m.ID()
                    if exists {
                        fmt.Printf("Deleting user: %d\n", id)
                    }
                    
                    return next.Mutate(ctx, m)
                })
            },
            ent.OpDelete|ent
        ),
    }
}

6.2 运行时Hooks(全局Hooks)

package main

import (
    "context"
    "fmt"
    "log"
    "time"

    "myapp/ent"
    "myapp/ent/hook"
)

func main() {
    client, err := ent.Open("postgres", "...")
    if err != nil {
        log.Fatal(err)
    }
    
    // 注册全局钩子
    client.Use(
        // 日志钩子 - 记录所有操作
        func(next ent.Mutator) ent.Mutator {
            return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
                start := time.Now()
                
                // 执行操作
                v, err := next.Mutate(ctx, m)
                
                // 记录日志
                log.Printf(
                    "Op=%s, Type=%s, Duration=%s, Error=%v",
                    m.Op().String(),
                    m.Type(),
                    time.Since(start),
                    err,
                )
                
                return v, err
            })
        },
        
        // 审计钩子 - 记录操作者
        func(next ent.Mutator) ent.Mutator {
            return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
                // 从context获取当前用户
                userID, ok := ctx.Value("user_id").(int)
                if ok {
                    // 根据Mutation类型设置审计字段
                    if m.Op().Is(ent.OpCreate) {
                        if setter, ok := m.(interface{ SetCreatedBy(int) }); ok {
                            setter.SetCreatedBy(userID)
                        }
                    }
                    if m.Op().Is(ent.OpUpdate | ent.OpUpdateOne) {
                        if setter, ok := m.(interface{ SetUpdatedBy(int) }); ok {
                            setter.SetUpdatedBy(userID)
                        }
                    }
                }
                
                return next.Mutate(ctx, m)
            })
        },
    )
    
    // 针对特定类型的钩子
    client.User.Use(
        func(next ent.Mutator) ent.Mutator {
            return hook.UserFunc(func(ctx context.Context, m *ent.UserMutation) (ent.Value, error) {
                // 密码加密
                if password, exists := m.Password(); exists {
                    hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
                    m.SetPassword(string(hashedPassword))
                }
                return next.Mutate(ctx, m)
            })
        },
    )
}

6.3 条件Hooks

// 只在特定条件下执行的钩子
func ConditionalHook() ent.Hook {
    hk := func(next ent.Mutator) ent.Mutator {
        return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
            // 执行验证逻辑
            if um, ok := m.(*ent.UserMutation); ok {
                email, exists := um.Email()
                if exists && !isValidEmail(email) {
                    return nil, fmt.Errorf("invalid email format: %s", email)
                }
            }
            return next.Mutate(ctx, m)
        })
    }
    
    // 只在创建和更新时执行
    return hook.On(hk, ent.OpCreate|ent.OpUpdate|ent.OpUpdateOne)
}

// 排除某些操作
func ExcludeHook() ent.Hook {
    hk := func(next ent.Mutator) ent.Mutator {
        return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
            // 你的逻辑
            return next.Mutate(ctx, m)
        })
    }
    
    // 排除删除操作
    return hook.Unless(hk, ent.OpDelete|ent.OpDeleteOne)
}

// 只对特定字段变更时执行
func FieldChangeHook() ent.Hook {
    return hook.If(
        func(next ent.Mutator) ent.Mutator {
            return hook.UserFunc(func(ctx context.Context, m *ent.UserMutation) (ent.Value, error) {
                oldEmail, _ := m.OldEmail(ctx)
                newEmail, _ := m.Email()
                
                // 发送邮箱变更通知
                sendEmailChangeNotification(oldEmail, newEmail)
                
                return next.Mutate(ctx, m)
            })
        },
        hook.HasFields("email"), // 只有email字段变更时才执行
    )
}

七、Privacy(隐私/访问控制)

Ent的Privacy层允许定义细粒度的访问控制规则,主要用于数据访问控制和权限管理,它在数据库操作级别提供细粒度的访问控制。以下是主要使用场景:

核心使用场景

  1. 多租户隔离 确保用户只能访问自己的数据。比如在SaaS应用中,用户A不能查询或修改用户B的数据。
// 示例:用户只能看到自己的文章
func FilterTenantRule() privacy.QueryRule {
    return privacy.FilterFunc(func(ctx context.Context, f privacy.Filter) error {
        userID := getUserIDFromContext(ctx)
        f.Where(article.HasOwnerWith(user.ID(userID)))
        return nil
    })
}
  1. 基于角色的访问控制(RBAC) 根据用户角色限制操作权限。比如只有管理员可以删除某些资源,普通用户只能读取。
  2. 行级权限控制 根据数据的状态或属性控制访问。例如,已发布的文章所有人可见,草稿只有作者可见。
  3. 字段级权限 控制敏感字段的访问,如用户的手机号、邮箱等只有本人或管理员可见。
  4. 审计和合规 强制执行数据访问策略,确保符合GDPR等数据保护法规。

Privacy的两种规则类型

  • QueryRule: 控制读取操作(Query, Get)
  • MutationRule: 控制写入操作(Create, Update, Delete)

Privacy层的优势是将权限逻辑与业务代码分离,在ORM层统一处理,避免在每个查询中手动添加权限检查,减少安全漏洞风险。

7.1 配置Privacy

首先在生成配置中启用Privacy:

// ent/generate.go
package ent

//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature privacy ./schema

7.2 定义Privacy规则

package schema

import (
    "context"

    "entgo.io/ent"
    "entgo.io/ent/privacy"
    "entgo.io/ent/schema/field"
    
    "myapp/ent/predicate"
)

type User struct {
    ent.Schema
}

func (User) Fields() []ent.Field {
    return []ent.Field{
        field.String("name"),
        field.String("email"),
        field.Int("tenant_id"),
        field.Enum("role").Values("admin", "user", "guest"),
    }
}

func (User) Policy() ent.Policy {
    return privacy.Policy{
        // 查询规则
        Query: privacy.QueryPolicy{
            // 管理员可以查看所有用户
            privacy.QueryRuleFunc(func(ctx context.Context, q ent.Query) error {
                if isAdmin(ctx) {
                    return privacy.Allow
                }
                return privacy.Skip // 继续下一个规则
            }),
            
            // 普通用户只能查看同一租户的用户
            privacy.QueryRuleFunc(func(ctx context.Context, q ent.Query) error {
                tenantID := getTenantID(ctx)
                if tenantID == 0 {
                    return privacy.Deny
                }
                
                // 添加租户过滤条件
                userQuery := q.(*ent.UserQuery)
                userQuery.Where(user.TenantIDEQ(tenantID))
                return privacy.Allow
            }),
        },
        
        // 变更规则(创建、更新、删除)
        Mutation: privacy.MutationPolicy{
            // 只有管理员可以创建用户
            privacy.OnMutationOperation(
                privacy.MutationRuleFunc(func(ctx context.Context, m ent.Mutation) error {
                    if isAdmin(ctx) {
                        return privacy.Allow
                    }
                    return privacy.Deny
                }),
                ent.OpCreate,
            ),
            
            // 用户只能更新自己的信息
            privacy.OnMutationOperation(
                privacy.MutationRuleFunc(func(ctx context.Context, m ent.Mutation) error {
                    if isAdmin(ctx) {
                        return privacy.Allow
                    }
                    
                    um := m.(*ent.UserMutation)
                    id, exists := um.ID()
                    if !exists {
                        return privacy.Deny
                    }
                    
                    currentUserID := getCurrentUserID(ctx)
                    if id == currentUserID {
                        return privacy.Allow
                    }
                    
                    return privacy.Deny
                }),
                ent.OpUpdate|ent.OpUpdateOne,
            ),
            
            // 只有管理员可以删除用户
            privacy.OnMutationOperation(
                privacy.DenyMutationOperationRule(ent.OpDelete|ent.OpDeleteOne),
            ),
        },
    }
}

// 辅助函数
func isAdmin(ctx context.Context) bool {
    role, ok := ctx.Value("role").(string)
    return ok && role == "admin"
}

func getTenantID(ctx context.Context) int {
    tenantID, _ := ctx.Value("tenant_id").(int)
    return tenantID
}

func getCurrentUserID(ctx context.Context) int {
    userID, _ := ctx.Value("user_id").(int)
    return userID
}

7.3 多租户隔离规则

package rule

import (
    "context"

    "entgo.io/ent/privacy"

    "myapp/ent"
    "myapp/ent/predicate"
    "myapp/ent/user"
)

// FilterTenantRule 自动添加租户过滤
type FilterTenantRule struct {
    privacy.QueryMutationRule
}

func (f FilterTenantRule) EvalQuery(ctx context.Context, q ent.Query) error {
    tenantID := getTenantID(ctx)
    if tenantID == 0 {
        return privacy.Denyf("missing tenant information")
    }
    
    // 根据查询类型添加过滤条件
    switch q := q.(type) {
    case *ent.UserQuery:
        q.Where(user.TenantIDEQ(tenantID))
    case *ent.PostQuery:
        q.Where(post.TenantIDEQ(tenantID))
    // ... 其他实体
    }
    
    return privacy.Skip
}

func (f FilterTenantRule) EvalMutation(ctx context.Context, m ent.Mutation) error {
    tenantID := getTenantID(ctx)
    if tenantID == 0 {
        return privacy.Denyf("missing tenant information")
    }
    
    // 创建时自动设置租户ID
    if m.Op().Is(ent.OpCreate) {
        if setter, ok := m.(interface{ SetTenantID(int) }); ok {
            setter.SetTenantID(tenantID)
        }
    }
    
    return privacy.Skip
}

7.4 跳过Privacy检查

// 使用privacy.DecisionContext跳过检查
func AdminOperation(ctx context.Context, client *ent.Client) error {
    // 允许所有操作
    allowCtx := privacy.DecisionContext(ctx, privacy.Allow)
    
    users, err := client.User.Query().All(allowCtx)
    // ...
    return err
}

八、拦截器(Interceptors)

拦截器是Ent v0.11引入的功能,用于拦截查询操作。

package main

import (
    "context"
    "log"

    "entgo.io/ent"
    
    entgo "myapp/ent"
    "myapp/ent/intercept"
)

func main() {
    client, _ := entgo.Open("postgres", "...")
    
    // 查询拦截器
    client.Intercept(
        // 日志拦截器
        intercept.Func(func(ctx context.Context, q intercept.Query) error {
            log.Printf("Query: Type=%s", q.Type())
            return nil
        }),
        
        // 软删除过滤器 - 自动过滤已删除的记录
        intercept.TraverseFunc(func(ctx context.Context, q intercept.Query) error {
            // 检查是否需要跳过软删除过滤
            if skipSoftDelete(ctx) {
                return nil
            }
            
            // 添加deleted_at IS NULL条件
            switch q := q.(type) {
            case *entgo.UserQuery:
                q.Where(user.DeletedAtIsNil())
            case *entgo.PostQuery:
                q.Where(post.DeletedAtIsNil())
            }
            
            return nil
        }),
    )
    
    // 特定类型的拦截器
    client.User.Intercept(
        intercept.TraverseUser(func(ctx context.Context, q *entgo.UserQuery) error {
            // 用户特定的拦截逻辑
            log.Println("Querying users...")
            return nil
        }),
    )
}

func skipSoftDelete(ctx context.Context) bool {
    skip, _ := ctx.Value("skip_soft_delete").(bool)
    return skip
}

九、数据库迁移

9.1 自动迁移

import (
	"context"

	"entgo.io/ent/dialect/migrate"
	"myapp/ent"
)

func AutoMigration(ctx context.Context, client *ent.Client) error {
    // 基本迁移
    err := client.Schema.Create(ctx)
    
    // 带选项的迁移
    err = client.Schema.Create(ctx,
        // 创建外键约束
        migrate.WithForeignKeys(true),
        // 创建索引
        migrate.WithDropIndex(true),
        migrate.WithDropColumn(true),
        // 全局唯一ID(用于联邦查询)
        migrate.WithGlobalUniqueID(true),
    )
    
    return err
}

9.2 版本化迁移(Atlas集成)

Ent与Atlas集成提供版本化迁移能力:

# 安装Atlas CLI
curl -sSf https://atlasgo.sh | sh

# 或使用Go安装
go install ariga.io/atlas/cmd/atlas@latest

配置版本化迁移:

// ent/migrate/main.go
//go:build ignore

package main

import (
    "context"
    "log"
    "os"

    "myapp/ent/migrate"

    atlas "ariga.io/atlas/sql/migrate"
    "entgo.io/ent/dialect"
    "entgo.io/ent/dialect/sql/schema"
    _ "github.com/lib/pq"
)

func main() {
    ctx := context.Background()
    
    // 创建迁移目录
    dir, err := atlas.NewLocalDir("ent/migrate/migrations")
    if err != nil {
        log.Fatalf("failed creating atlas migration directory: %v", err)
    }
    
    // 迁移选项
    opts := []schema.MigrateOption{
        schema.WithDir(dir),
        schema.WithMigrationMode(schema.ModeReplay),
        schema.WithDialect(dialect.Postgres),
        schema.WithFormatter(atlas.DefaultFormatter),
    }
    
    // 生成迁移文件
    err = migrate.NamedDiff(ctx, os.Getenv("DATABASE_URL"), os.Args[1], opts...)
    if err != nil {
        log.Fatalf("failed generating migration file: %v", err)
    }

    log.Println("Migration file generated successfully")
}

使用Atlas命令:

# 生成迁移文件
go run -mod=mod ent/migrate/main.go add_user_table

# 应用迁移
atlas migrate apply \
  --dir "file://ent/migrate/migrations" \
  --url "postgres://localhost:5432/myapp?sslmode=disable"

# 查看迁移状态
atlas migrate status \
  --dir "file://ent/migrate/migrations" \
  --url "postgres://localhost:5432/myapp?sslmode=disable"

# 回滚迁移
atlas migrate down \
  --dir "file://ent/migrate/migrations" \
  --url "postgres://localhost:5432/myapp?sslmode=disable"

9.3 迁移钩子

func MigrationWithHooks(ctx context.Context, client *ent.Client) error {
    err := client.Schema.Create(
        ctx,
        migrate.WithHooks(func(next schema.Creator) schema.Creator {
            return schema.CreateFunc(func(ctx context.Context, tables ...*schema.Table) error {
                // 迁移前的操作
                log.Println("Starting migration...")
                
                // 执行迁移
                if err := next.Create(ctx, tables...); err != nil {
                    return err
                }
                
                // 迁移后的操作
                log.Println("Migration completed!")
                return nil
            })
        }),
    )
    
    return err
}

9.4 检查Schema差异

func CheckSchemaDiff(ctx context.Context, client *ent.Client) error {
    // 获取当前Schema和数据库Schema的差异
    changes, err := client.Schema.Diff(ctx)
    if err != nil {
        return err
    }
    
    for _, change := range changes {
        log.Printf("Change: %s\n", change)
    }
    
    return nil
}

十、高级功能

10.1 自定义ID生成

package schema

import (
    "entgo.io/ent"
    "entgo.io/ent/schema/field"
    
    "github.com/google/uuid"
    "github.com/sony/sonyflake"
)

var sf *sonyflake.Sonyflake

func init() {
    sf = sonyflake.NewSonyflake(sonyflake.Settings{})
}

// UUID作为ID
type User struct {
    ent.Schema
}

func (User) Fields() []ent.Field {
    return []ent.Field{
        field.UUID("id", uuid.UUID{}).
            Default(uuid.New).
            Immutable(),
        field.String("name"),
    }
}

// Snowflake ID
type Post struct {
    ent.Schema
}

func (Post) Fields() []ent.Field {
    return []ent.Field{
        field.Int64("id").
            DefaultFunc(func() int64 {
                id, _ := sf.NextID()
                return int64(id)
            }).
            Immutable(),
        field.String("title"),
    }
}

// 自定义字符串ID
type Order struct {
    ent.Schema
}

func (Order) Fields() []ent.Field {
    return []ent.Field{
        field.String("id").
            DefaultFunc(func() string {
                return fmt.Sprintf("ORD-%s", uuid.New().String()[:8])
            }).
            Immutable(),
    }
}

10.2 自定义谓词(Predicates)

package customsql

import (
    "entgo.io/ent/dialect/sql"
    
    "myapp/ent/predicate"
    "myapp/ent/user"
)

// FullTextSearch 全文搜索谓词
func FullTextSearch(term string) predicate.User {
    return predicate.User(func(s *sql.Selector) {
        s.Where(sql.ExprP(
            "to_tsvector('english', name || ' ' || COALESCE(bio, '')) @@ plainto_tsquery('english', $1)",
            term,
        ))
    })
}

// WithinDistance 地理位置范围查询
func WithinDistance(lat, lng, distanceKm float64) predicate.User {
    return predicate.User(func(s *sql.Selector) {
        s.Where(sql.ExprP(
            `ST_DWithin(
                location::geography,
                ST_SetSRID(ST_MakePoint($1, $2), 4326)::geography,
                $3
            )`,
            lng, lat, distanceKm*1000,
        ))
    })
}

// JSONContains JSON字段包含查询
func JSONContains(field, key, value string) predicate.User {
    return predicate.User(func(s *sql.Selector) {
        s.Where(sql.ExprP(
            fmt.Sprintf("%s->>$1 = $2", field),
            key, value,
        ))
    })
}

// 使用示例
func SearchUsers(ctx context.Context, client *ent.Client) ([]*ent.User, error) {
    return client.User.
        Query().
        Where(
            user.And(
                user.StatusEQ("active"),
                FullTextSearch("engineer"),
                WithinDistance(37.7749, -122.4194, 50), // 旧金山50公里内
            ),
        ).
        All(ctx)
}

10.3 原生SQL查询

package main

import (
    "context"
    "database/sql"
    
    "entgo.io/ent/dialect"
    entsql "entgo.io/ent/dialect/sql"
)

// 使用原生SQL查询
func RawSQLQuery(ctx context.Context, client *ent.Client) error {
    // 获取底层数据库连接
    var results []struct {
        ID    int    `json:"id"`
        Name  string `json:"name"`
        Count int    `json:"count"`
    }
    
    err := client.User.
        Query().
        Modify(func(s *entsql.Selector) {
            s.Select(
                s.C(user.FieldID),
                s.C(user.FieldName),
                entsql.As(entsql.Count("*"), "count"),
            ).
            From(s.Table()).
            LeftJoin(
                entsql.Table(post.Table),
            ).
            On(
                s.C(user.FieldID),
                entsql.Table(post.Table).C(post.FieldAuthorID),
            ).
            GroupBy(s.C(user.FieldID))
        }).
        Scan(ctx, &results)
    
    return err
}

// 直接执行原生SQL
func ExecRawSQL(ctx context.Context, client *ent.Client) error {
    // 获取驱动
    drv := client.Driver()
    
    // 执行查询
    rows, err := drv.Query(ctx, 
        "SELECT id, name FROM users WHERE status = $1 LIMIT $2", 
        []interface{}{"active", 10},
    )
    if err != nil {
        return err
    }
    defer rows.Close()
    
    for rows.Next() {
        var id int
        var name string
        if err := rows.Scan(&id, &name); err != nil {
            return err
        }
        fmt.Printf("User: %d - %s\n", id, name)
    }
    
    return rows.Err()
}

// 使用sql.DB直接操作
func DirectDBAccess(ctx context.Context, client *ent.Client) error {
    // 获取*sql.DB
    db := client.DB()
    
    // 执行原生查询
    rows, err := db.QueryContext(ctx,
        `SELECT u.id, u.name, COUNT(p.id) as post_count
         FROM users u
         LEFT JOIN posts p ON u.id = p.author_id
         GROUP BY u.id
         ORDER BY post_count DESC
         LIMIT 10`)
    if err != nil {
        return err
    }
    defer rows.Close()
    
    // 处理结果...
    return nil
}

10.4 分页器

package pagination

import (
    "context"
    "encoding/base64"
    "encoding/json"
    "fmt"
    
    "myapp/ent"
    "myapp/ent/user"
)

// PageInfo 分页信息
type PageInfo struct {
    HasNextPage     bool   `json:"has_next_page"`
    HasPreviousPage bool   `json:"has_previous_page"`
    StartCursor     string `json:"start_cursor,omitempty"`
    EndCursor       string `json:"end_cursor,omitempty"`
    TotalCount      int    `json:"total_count"`
}

// UserEdge 用户边缘
type UserEdge struct {
    Node   *ent.User `json:"node"`
    Cursor string    `json:"cursor"`
}

// UserConnection 用户连接(游标分页结果)
type UserConnection struct {
    Edges    []*UserEdge `json:"edges"`
    PageInfo *PageInfo   `json:"page_info"`
}

// Cursor 游标结构
type Cursor struct {
    ID        int   `json:"id"`
    Timestamp int64 `json:"ts"`
}

// EncodeCursor 编码游标
func EncodeCursor(c Cursor) string {
    data, _ := json.Marshal(c)
    return base64.StdEncoding.EncodeToString(data)
}

// DecodeCursor 解码游标
func DecodeCursor(s string) (Cursor, error) {
    var c Cursor
    data, err := base64.StdEncoding.DecodeString(s)
    if err != nil {
        return c, err
    }
    err = json.Unmarshal(data, &c)
    return c, err
}

// PaginateUsers 游标分页查询用户
func PaginateUsers(
    ctx context.Context,
    client *ent.Client,
    first *int,
    after *string,
    last *int,
    before *string,
) (*UserConnection, error) {
    query := client.User.Query().
        Where(user.StatusEQ("active")).
        Order(ent.Desc(user.FieldCreatedAt))
    
    // 获取总数
    totalCount, err := query.Clone().Count(ctx)
    if err != nil {
        return nil, err
    }
    
    // 处理after游标
    if after != nil {
        cursor, err := DecodeCursor(*after)
        if err != nil {
            return nil, err
        }
        query = query.Where(user.IDLT(cursor.ID))
    }
    
    // 处理before游标
    if before != nil {
        cursor, err := DecodeCursor(*before)
        if err != nil {
            return nil, err
        }
        query = query.Where(user.IDGT(cursor.ID))
    }
    
    // 设置限制(多取一个用于判断是否有下一页)
    limit := 10
    if first != nil {
        limit = *first
    }
    query = query.Limit(limit + 1)
    
    users, err := query.All(ctx)
    if err != nil {
        return nil, err
    }
    
    // 构建结果
    hasNextPage := len(users) > limit
    if hasNextPage {
        users = users[:limit]
    }
    
    edges := make([]*UserEdge, len(users))
    for i, u := range users {
        edges[i] = &UserEdge{
            Node: u,
            Cursor: EncodeCursor(Cursor{
                ID:        u.ID,
                Timestamp: u.CreatedAt.Unix(),
            }),
        }
    }
    
    var startCursor, endCursor string
    if len(edges) > 0 {
        startCursor = edges[0].Cursor
        endCursor = edges[len(edges)-1].Cursor
    }
    
    return &UserConnection{
        Edges: edges,
        PageInfo: &PageInfo{
            HasNextPage:     hasNextPage,
            HasPreviousPage: after != nil,
            StartCursor:     startCursor,
            EndCursor:       endCursor,
            TotalCount:      totalCount,
        },
    }, nil
}

// OffsetPagination 偏移分页
type OffsetPagination struct {
    Page     int `json:"page"`
    PageSize int `json:"page_size"`
    Total    int `json:"total"`
    Pages    int `json:"pages"`
}

type PaginatedUsers struct {
    Users      []*ent.User       `json:"users"`
    Pagination *OffsetPagination `json:"pagination"`
}

func PaginateUsersOffset(
    ctx context.Context,
    client *ent.Client,
    page, pageSize int,
) (*PaginatedUsers, error) {
    if page < 1 {
        page = 1
    }
    if pageSize < 1 || pageSize > 100 {
        pageSize = 10
    }
    
    query := client.User.Query().Where(user.StatusEQ("active"))
    
    // 获取总数
    total, err := query.Clone().Count(ctx)
    if err != nil {
        return nil, err
    }
    
    // 计算分页
    offset := (page - 1) * pageSize
    pages := (total + pageSize - 1) / pageSize
    
    users, err := query.
        Order(ent.Desc(user.FieldCreatedAt)).
        Offset(offset).
        Limit(pageSize).
        All(ctx)
    if err != nil {
        return nil, err
    }
    
    return &PaginatedUsers{
        Users: users,
        Pagination: &OffsetPagination{
            Page:     page,
            PageSize: pageSize,
            Total:    total,
            Pages:    pages,
        },
    }, nil
}

10.5 软删除实现

package schema

import (
    "context"
    "time"

    "entgo.io/ent"
    "entgo.io/ent/dialect/sql"
    "entgo.io/ent/schema/field"
    "entgo.io/ent/schema/mixin"
    
    gen "myapp/ent"
    "myapp/ent/hook"
    "myapp/ent/intercept"
)

// SoftDeleteMixin 软删除混入
type SoftDeleteMixin struct {
    mixin.Schema
}

func (SoftDeleteMixin) Fields() []ent.Field {
    return []ent.Field{
        field.Time("deleted_at").
            Optional().
            Nillable(),
    }
}

// SoftDeleteKey 用于context的键
type softDeleteKey struct{}

// SkipSoftDelete 返回跳过软删除过滤的context
func SkipSoftDelete(ctx context.Context) context.Context {
    return context.WithValue(ctx, softDeleteKey{}, true)
}

// 拦截器:查询时自动过滤已删除的记录
func (SoftDeleteMixin) Interceptors() []ent.Interceptor {
    return []ent.Interceptor{
        intercept.TraverseFunc(func(ctx context.Context, q intercept.Query) error {
            // 检查是否跳过软删除过滤
            if skip, _ := ctx.Value(softDeleteKey{}).(bool); skip {
                return nil
            }
            
            // 使用类型断言添加过滤条件
            sf, ok := q.(interface {
                WhereP(...func(*sql.Selector))
            })
            if !ok {
                return nil
            }
            
            sf.WhereP(func(s *sql.Selector) {
                s.Where(sql.IsNull(s.C("deleted_at")))
            })
            
            return nil
        }),
    }
}

// 钩子:将删除操作转换为软删除
func (SoftDeleteMixin) Hooks() []ent.Hook {
    return []ent.Hook{
        hook.On(
            func(next ent.Mutator) ent.Mutator {
                return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
                    // 检查是否强制硬删除
                    if skip, _ := ctx.Value(softDeleteKey{}).(bool); skip {
                        return next.Mutate(ctx, m)
                    }
                    
                    // 获取删除的实体类型和ID
                    mx, ok := m.(interface {
                        SetDeletedAt(time.Time)
                        WhereP(...func(*sql.Selector))
                        SetOp(ent.Op)
                        Client() *gen.Client
                    })
                    if !ok {
                        return next.Mutate(ctx, m)
                    }
                    
                    // 将删除转换为更新
                    mx.SetDeletedAt(time.Now())
                    mx.SetOp(ent.OpUpdate)
                    
                    return next.Mutate(ctx, m)
                })
            },
            ent.OpDelete|ent.OpDeleteOne,
        ),
    }
}

// User实体使用软删除
type User struct {
    ent.Schema
}

func (User) Mixin() []ent.Mixin {
    return []ent.Mixin{
        SoftDeleteMixin{},
    }
}

// 使用示例
func SoftDeleteExample(ctx context.Context, client *ent.Client) error {
    // 普通删除(软删除)
    err := client.User.DeleteOneID(1).Exec(ctx)
    
    // 查询时自动过滤已删除的记录
    users, err := client.User.Query().All(ctx)
    
    // 查询包括已删除的记录
    allUsers, err := client.User.Query().All(SkipSoftDelete(ctx))
    
    // 强制硬删除
    err = client.User.DeleteOneID(1).Exec(SkipSoftDelete(ctx))
    
    // 恢复已删除的记录
    err = client.User.
        UpdateOneID(1).
        ClearDeletedAt().
        Exec(SkipSoftDelete(ctx))
    
    return err
}

10.6 乐观锁实现

package schema

import (
    "context"
    "fmt"

    "entgo.io/ent"
    "entgo.io/ent/schema/field"
    "entgo.io/ent/schema/mixin"
    
    gen "myapp/ent"
    "myapp/ent/hook"
)

// OptimisticLockMixin 乐观锁混入
type OptimisticLockMixin struct {
    mixin.Schema
}

func (OptimisticLockMixin) Fields() []ent.Field {
    return []ent.Field{
        field.Int64("version").
            Default(1).
            Comment("乐观锁版本号"),
    }
}

func (OptimisticLockMixin) Hooks() []ent.Hook {
    return []ent.Hook{
        hook.On(
            func(next ent.Mutator) ent.Mutator {
                return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
                    // 获取旧版本号
                    mx, ok := m.(interface {
                        OldVersion(ctx context.Context) (int64, error)
                        SetVersion(int64)
                        WhereP(...func(*sql.Selector))
                    })
                    if !ok {
                        return next.Mutate(ctx, m)
                    }
                    
                    oldVersion, err := mx.OldVersion(ctx)
                    if err != nil {
                        return nil, err
                    }
                    
                    // 添加版本号条件
                    mx.WhereP(func(s *sql.Selector) {
                        s.Where(sql.EQ(s.C("version"), oldVersion))
                    })
                    
                    // 设置新版本号
                    mx.SetVersion(oldVersion + 1)
                    
                    // 执行更新
                    v, err := next.Mutate(ctx, m)
                    if err != nil {
                        return nil, err
                    }
                    
                    // 检查是否更新成功(返回的affected rows)
                    // 如果版本号不匹配,会返回0行被更新
                    
                    return v, nil
                })
            },
            ent.OpUpdate|ent.OpUpdateOne,
        ),
    }
}

// 使用乐观锁的更新函数
func UpdateWithOptimisticLock(ctx context.Context, client *ent.Client, userID int, name string) error {
    // 获取当前实体
    user, err := client.User.Get(ctx, userID)
    if err != nil {
        return err
    }
    
    // 尝试更新
    affected, err := client.User.
        Update().
        Where(
            user.IDEQ(userID),
            user.VersionEQ(user.Version),
        ).
        SetName(name).
        SetVersion(user.Version + 1).
        Save(ctx)
    
    if err != nil {
        return err
    }
    
    if affected == 0 {
        return fmt.Errorf("optimistic lock conflict: user %d has been modified", userID)
    }
    
    return nil
}

// 带重试的乐观锁更新
func UpdateWithRetry(ctx context.Context, client *ent.Client, userID int, updateFn func(*ent.User) error, maxRetries int) error {
    for i := 0; i < maxRetries; i++ {
        user, err := client.User.Get(ctx, userID)
        if err != nil {
            return err
        }
        
        // 应用更新
        if err := updateFn(user); err != nil {
            return err
        }
        
        // 尝试保存
        affected, err := client.User.
            Update().
            Where(
                user.IDEQ(userID),
                user.VersionEQ(user.Version),
            ).
            SetName(user.Name).
            SetVersion(user.Version + 1).
            Save(ctx)
        
        if err != nil {
            return err
        }
        
        if affected > 0 {
            return nil // 更新成功
        }
        
        // 版本冲突,重试
        log.Printf("Optimistic lock conflict, retry %d/%d", i+1, maxRetries)
    }
    
    return fmt.Errorf("failed to update after %d retries", maxRetries)
}

10.7 Redis缓存集成

📋 使用场景

为什么需要缓存?

  • 🚀 提升性能 - 减少数据库查询,降低响应时间
  • 💰 降低成本 - 减轻数据库负载,节省资源
  • 📈 提高并发 - 处理更多请求
  • 🔄 减少延迟 - 热点数据快速访问

适用场景:

  1. 读多写少的数据 - 用户信息、配置信息
  2. 计算密集型查询 - 统计数据、聚合结果
  3. 频繁访问的数据 - 热门商品、文章
  4. 会话数据 - 登录状态、购物车
  5. 临时数据 - 验证码、限流计数

实际业务例子:

  • 电商系统:商品详情、分类信息
  • 社交平台:用户资料、关注列表
  • 内容平台:文章详情、评论列表

安装依赖

go get github.com/redis/go-redis/v9

基础缓存实现

package cache

import (
    "context"
    "encoding/json"
    "fmt"
    "time"

    "github.com/redis/go-redis/v9"
    
    "myapp/ent"
)

type Cache struct {
    redis  *redis.Client
    client *ent.Client
}

func NewCache(redisClient *redis.Client, entClient *ent.Client) *Cache {
    return &Cache{
        redis:  redisClient,
        client: entClient,
    }
}

// 生成缓存键
func (c *Cache) userKey(id int) string {
    return fmt.Sprintf("user:%d", id)
}

// GetUser 从缓存获取用户,缓存未命中时从数据库加载
func (c *Cache) GetUser(ctx context.Context, id int) (*ent.User, error) {
    key := c.userKey(id)
    
    // 1. 尝试从Redis获取
    data, err := c.redis.Get(ctx, key).Bytes()
    if err == nil {
        // 缓存命中,反序列化
        var user ent.User
        if err := json.Unmarshal(data, &user); err == nil {
            return &user, nil
        }
    }
    
    // 2. 缓存未命中,从数据库查询
    user, err := c.client.User.Get(ctx, id)
    if err != nil {
        return nil, err
    }
    
    // 3. 写入缓存
    if data, err := json.Marshal(user); err == nil {
        c.redis.Set(ctx, key, data, 5*time.Minute)
    }
    
    return user, nil
}

// InvalidateUser 删除用户缓存
func (c *Cache) InvalidateUser(ctx context.Context, id int) error {
    return c.redis.Del(ctx, c.userKey(id)).Err()
}

// SetUser 设置用户缓存
func (c *Cache) SetUser(ctx context.Context, user *ent.User, ttl time.Duration) error {
    key := c.userKey(user.ID)
    data, err := json.Marshal(user)
    if err != nil {
        return err
    }
    return c.redis.Set(ctx, key, data, ttl).Err()
}

Cache-Aside模式(推荐)

package repository

import (
    "context"
    "encoding/json"
    "time"

    "github.com/redis/go-redis/v9"
    
    "myapp/ent"
    "myapp/ent/user"
)

type UserRepository struct {
    db    *ent.Client
    cache *redis.Client
}

func NewUserRepository(db *ent.Client, cache *redis.Client) *UserRepository {
    return &UserRepository{db: db, cache: cache}
}

// GetByID 通过ID获取用户(带缓存)
func (r *UserRepository) GetByID(ctx context.Context, id int) (*ent.User, error) {
    key := fmt.Sprintf("user:%d", id)
    
    // 1. 尝试从缓存获取
    val, err := r.cache.Get(ctx, key).Result()
    if err == nil {
        var u ent.User
        if err := json.Unmarshal([]byte(val), &u); err == nil {
            return &u, nil
        }
    }
    
    // 2. 缓存未命中,查询数据库
    u, err := r.db.User.Get(ctx, id)
    if err != nil {
        return nil, err
    }
    
    // 3. 写入缓存(异步,失败不影响结果)
    go func() {
        data, _ := json.Marshal(u)
        r.cache.Set(context.Background(), key, data, 10*time.Minute)
    }()
    
    return u, nil
}

// Update 更新用户(删除缓存)
func (r *UserRepository) Update(ctx context.Context, id int, updateFn func(*ent.UserUpdateOne) *ent.UserUpdateOne) (*ent.User, error) {
    // 1. 更新数据库
    u, err := updateFn(r.db.User.UpdateOneID(id)).Save(ctx)
    if err != nil {
        return nil, err
    }
    
    // 2. 删除缓存
    key := fmt.Sprintf("user:%d", id)
    r.cache.Del(ctx, key)
    
    return u, nil
}

// Delete 删除用户(删除缓存)
func (r *UserRepository) Delete(ctx context.Context, id int) error {
    // 1. 删除数据库记录
    if err := r.db.User.DeleteOneID(id).Exec(ctx); err != nil {
        return err
    }
    
    // 2. 删除缓存
    key := fmt.Sprintf("user:%d", id)
    r.cache.Del(ctx, key)
    
    return nil
}

// GetActiveUsers 获取活跃用户列表(带缓存)
func (r *UserRepository) GetActiveUsers(ctx context.Context, limit int) ([]*ent.User, error) {
    key := fmt.Sprintf("users:active:%d", limit)
    
    // 1. 尝试从缓存获取
    val, err := r.cache.Get(ctx, key).Result()
    if err == nil {
        var users []*ent.User
        if err := json.Unmarshal([]byte(val), &users); err == nil {
            return users, nil
        }
    }
    
    // 2. 查询数据库
    users, err := r.db.User.
        Query().
        Where(user.StatusEQ("active")).
        Limit(limit).
        All(ctx)
    if err != nil {
        return nil, err
    }
    
    // 3. 写入缓存
    go func() {
        data, _ := json.Marshal(users)
        r.cache.Set(context.Background(), key, data, 5*time.Minute)
    }()
    
    return users, nil
}

使用Ent Hook自动管理缓存

package schema

import (
    "context"
    "fmt"

    "github.com/redis/go-redis/v9"
    
    "entgo.io/ent"
    gen "myapp/ent"
    "myapp/ent/hook"
)

// 获取Redis客户端(从context或全局变量)
func getRedisClient(ctx context.Context) *redis.Client {
    if client, ok := ctx.Value("redis").(*redis.Client); ok {
        return client
    }
    return globalRedisClient // 或使用全局客户端
}

// User schema with cache hooks
type User struct {
    ent.Schema
}

func (User) Hooks() []ent.Hook {
    return []ent.Hook{
        // 创建后写入缓存
        hook.On(
            func(next ent.Mutator) ent.Mutator {
                return hook.UserFunc(func(ctx context.Context, m *gen.UserMutation) (ent.Value, error) {
                    // 执行创建
                    v, err := next.Mutate(ctx, m)
                    if err != nil {
                        return v, err
                    }
                    
                    // 获取创建的用户
                    user := v.(*ent.User)
                    
                    // 写入缓存
                    if rdb := getRedisClient(ctx); rdb != nil {
                        key := fmt.Sprintf("user:%d", user.ID)
                        data, _ := json.Marshal(user)
                        rdb.Set(ctx, key, data, 10*time.Minute)
                    }
                    
                    return v, nil
                })
            },
            ent.OpCreate,
        ),
        
        // 更新后删除缓存
        hook.On(
            func(next ent.Mutator) ent.Mutator {
                return hook.UserFunc(func(ctx context.Context, m *gen.UserMutation) (ent.Value, error) {
                    // 获取用户ID
                    id, exists := m.ID()
                    
                    // 执行更新
                    v, err := next.Mutate(ctx, m)
                    if err != nil {
                        return v, err
                    }
                    
                    // 删除缓存
                    if exists {
                        if rdb := getRedisClient(ctx); rdb != nil {
                            key := fmt.Sprintf("user:%d", id)
                            rdb.Del(ctx, key)
                        }
                    }
                    
                    return v, nil
                })
            },
            ent.OpUpdate|ent.OpUpdateOne,
        ),
        
        // 删除后清除缓存
        hook.On(
            func(next ent.Mutator) ent.Mutator {
                return hook.UserFunc(func(ctx context.Context, m *gen.UserMutation) (ent.Value, error) {
                    // 获取用户ID
                    id, exists := m.ID()
                    
                    // 执行删除
                    v, err := next.Mutate(ctx, m)
                    if err != nil {
                        return v, err
                    }
                    
                    // 删除缓存
                    if exists {
                        if rdb := getRedisClient(ctx); rdb != nil {
                            key := fmt.Sprintf("user:%d", id)
                            rdb.Del(ctx, key)
                        }
                    }
                    
                    return v, nil
                })
            },
            ent.OpDelete|ent.OpDeleteOne,
        ),
    }
}

缓存预热

package cache

import (
    "context"
    "encoding/json"
    "time"

    "github.com/redis/go-redis/v9"
    
    "myapp/ent"
    "myapp/ent/user"
)

// WarmUpCache 预热缓存
func WarmUpCache(ctx context.Context, db *ent.Client, cache *redis.Client) error {
    // 1. 加载热门用户
    users, err := db.User.
        Query().
        Where(user.StatusEQ("active")).
        Order(ent.Desc(user.FieldLoginCount)).
        Limit(100).
        All(ctx)
    if err != nil {
        return err
    }
    
    // 2. 批量写入缓存
    pipe := cache.Pipeline()
    for _, u := range users {
        key := fmt.Sprintf("user:%d", u.ID)
        data, _ := json.Marshal(u)
        pipe.Set(ctx, key, data, 10*time.Minute)
    }
    
    _, err = pipe.Exec(ctx)
    return err
}

缓存失效策略

package cache

import (
    "context"
    "fmt"
    "time"

    "github.com/redis/go-redis/v9"
)

type InvalidationStrategy struct {
    cache *redis.Client
}

// InvalidateUserRelatedCaches 删除用户相关的所有缓存
func (s *InvalidationStrategy) InvalidateUserRelatedCaches(ctx context.Context, userID int) error {
    patterns := []string{
        fmt.Sprintf("user:%d", userID),
        fmt.Sprintf("user:%d:posts:*", userID),
        fmt.Sprintf("user:%d:profile", userID),
        "users:active:*", // 列表缓存也要删除
    }
    
    for _, pattern := range patterns {
        keys, err := s.cache.Keys(ctx, pattern).Result()
        if err != nil {
            continue
        }
        
        if len(keys) > 0 {
            s.cache.Del(ctx, keys...)
        }
    }
    
    return nil
}

// SetWithTags 设置带标签的缓存
func (s *InvalidationStrategy) SetWithTags(ctx context.Context, key string, value interface{}, ttl time.Duration, tags ...string) error {
    // 1. 设置主键
    data, _ := json.Marshal(value)
    if err := s.cache.Set(ctx, key, data, ttl).Err(); err != nil {
        return err
    }
    
    // 2. 为每个标签建立索引
    for _, tag := range tags {
        tagKey := fmt.Sprintf("tag:%s", tag)
        s.cache.SAdd(ctx, tagKey, key)
        s.cache.Expire(ctx, tagKey, ttl)
    }
    
    return nil
}

// InvalidateByTag 通过标签删除缓存
func (s *InvalidationStrategy) InvalidateByTag(ctx context.Context, tag string) error {
    tagKey := fmt.Sprintf("tag:%s", tag)
    
    // 获取该标签下的所有key
    keys, err := s.cache.SMembers(ctx, tagKey).Result()
    if err != nil {
        return err
    }
    
    // 删除所有相关key
    if len(keys) > 0 {
        keys = append(keys, tagKey)
        s.cache.Del(ctx, keys...)
    }
    
    return nil
}

防止缓存穿透

package cache

import (
    "context"
    "encoding/json"
    "errors"
    "time"

    "github.com/redis/go-redis/v9"
    
    "myapp/ent"
)

var ErrNotFound = errors.New("not found")

// GetUserWithNullCache 使用空值缓存防止穿透
func GetUserWithNullCache(ctx context.Context, db *ent.Client, cache *redis.Client, id int) (*ent.User, error) {
    key := fmt.Sprintf("user:%d", id)
    
    // 1. 尝试从缓存获取
    val, err := cache.Get(ctx, key).Result()
    if err == nil {
        // 检查是否是空值缓存
        if val == "null" {
            return nil, ErrNotFound
        }
        
        var user ent.User
        if err := json.Unmarshal([]byte(val), &user); err == nil {
            return &user, nil
        }
    }
    
    // 2. 查询数据库
    user, err := db.User.Get(ctx, id)
    if err != nil {
        // 如果不存在,缓存一个空值(短时间)
        if ent.IsNotFound(err) {
            cache.Set(ctx, key, "null", 1*time.Minute)
        }
        return nil, err
    }
    
    // 3. 缓存正常数据
    data, _ := json.Marshal(user)
    cache.Set(ctx, key, data, 10*time.Minute)
    
    return user, nil
}

防止缓存雪崩

package cache

import (
    "math/rand"
    "time"
)

// RandomTTL 返回带随机抖动的TTL,防止同时过期
func RandomTTL(base time.Duration, jitter float64) time.Duration {
    if jitter <= 0 || jitter >= 1 {
        return base
    }
    
    // 添加±jitter范围的随机时间
    delta := time.Duration(float64(base) * jitter)
    offset := time.Duration(rand.Int63n(int64(delta*2))) - delta
    
    return base + offset
}

// 使用示例
func SetUserWithJitter(ctx context.Context, cache *redis.Client, user *ent.User) error {
    key := fmt.Sprintf("user:%d", user.ID)
    data, _ := json.Marshal(user)
    
    // 基础TTL为10分钟,添加±20%的抖动
    ttl := RandomTTL(10*time.Minute, 0.2)
    
    return cache.Set(ctx, key, data, ttl).Err()
}

防止缓存击穿(热点key)

package cache

import (
    "context"
    "encoding/json"
    "sync"
    "time"

    "github.com/redis/go-redis/v9"
    
    "myapp/ent"
)

type SingleFlight struct {
    mu    sync.Mutex
    calls map[string]*call
}

type call struct {
    wg  sync.WaitGroup
    val *ent.User
    err error
}

func NewSingleFlight() *SingleFlight {
    return &SingleFlight{
        calls: make(map[string]*call),
    }
}

// GetUser 使用singleflight防止缓存击穿
func (sf *SingleFlight) GetUser(ctx context.Context, db *ent.Client, cache *redis.Client, id int) (*ent.User, error) {
    key := fmt.Sprintf("user:%d", id)
    
    // 1. 尝试从缓存获取
    val, err := cache.Get(ctx, key).Result()
    if err == nil {
        var user ent.User
        if err := json.Unmarshal([]byte(val), &user); err == nil {
            return &user, nil
        }
    }
    
    // 2. 缓存未命中,使用singleflight
    sf.mu.Lock()
    if c, ok := sf.calls[key]; ok {
        // 已有其他goroutine在加载,等待结果
        sf.mu.Unlock()
        c.wg.Wait()
        return c.val, c.err
    }
    
    // 创建新的call
    c := &call{}
    c.wg.Add(1)
    sf.calls[key] = c
    sf.mu.Unlock()
    
    // 3. 从数据库加载
    c.val, c.err = db.User.Get(ctx, id)
    if c.err == nil {
        // 写入缓存
        data, _ := json.Marshal(c.val)
        cache.Set(ctx, key, data, 10*time.Minute)
    }
    
    // 4. 通知等待的goroutine
    c.wg.Done()
    
    sf.mu.Lock()
    delete(sf.calls, key)
    sf.mu.Unlock()
    
    return c.val, c.err
}

完整示例:带缓存的服务层

package service

import (
    "context"
    "time"

    "github.com/redis/go-redis/v9"
    
    "myapp/cache"
    "myapp/ent"
    "myapp/ent/user"
)

type UserService struct {
    db         *ent.Client
    cache      *redis.Client
    repository *cache.UserRepository
    sf         *cache.SingleFlight
}

func NewUserService(db *ent.Client, cache *redis.Client) *UserService {
    return &UserService{
        db:         db,
        cache:      cache,
        repository: cache.NewUserRepository(db, cache),
        sf:         cache.NewSingleFlight(),
    }
}

// GetUser 获取用户(使用所有缓存策略)
func (s *UserService) GetUser(ctx context.Context, id int) (*ent.User, error) {
    // 使用singleflight + 空值缓存 + 随机TTL
    return s.sf.GetUser(ctx, s.db, s.cache, id)
}

// UpdateUser 更新用户(自动失效缓存)
func (s *UserService) UpdateUser(ctx context.Context, id int, name string) (*ent.User, error) {
    return s.repository.Update(ctx, id, func(u *ent.UserUpdateOne) *ent.UserUpdateOne {
        return u.SetName(name)
    })
}

// GetActiveUsers 获取活跃用户列表
func (s *UserService) GetActiveUsers(ctx context.Context, limit int) ([]*ent.User, error) {
    return s.repository.GetActiveUsers(ctx, limit)
}

// 启动时预热缓存
func (s *UserService) WarmUp(ctx context.Context) error {
    return cache.WarmUpCache(ctx, s.db, s.cache)
}

缓存最佳实践总结

1. 缓存策略选择:

  • Cache-Aside(旁路缓存):最常用,应用负责缓存管理
  • Read-Through:读穿透,由缓存层负责加载数据
  • Write-Through:写穿透,同步写缓存和数据库
  • Write-Behind:异步写,先写缓存,异步写数据库

2. 缓存更新策略:

  • 更新时删除缓存(推荐)
  • 更新时更新缓存(数据一致性要求高时)
  • 定时失效(适合计算密集型数据)

3. 三大问题解决:

  • 缓存穿透:布隆过滤器 + 空值缓存
  • 缓存击穿:Singleflight + 互斥锁
  • 缓存雪崩:随机TTL + 多级缓存

4. 监控指标:

  • 缓存命中率(Hit Rate)
  • 缓存延迟(Latency)
  • 内存使用量
  • 过期Key数量

十一、GraphQL集成

Ent原生支持GraphQL,可以自动生成GraphQL schema和resolver。

11.1 配置GraphQL生成

// ent/entc.go
//go:build ignore

package main

import (
    "log"

    "entgo.io/contrib/entgql"
    "entgo.io/ent/entc"
    "entgo.io/ent/entc/gen"
)

func main() {
    ex, err := entgql.NewExtension(
        // 生成GQL schema
        entgql.WithSchemaGenerator(),
        entgql.WithSchemaPath("graph/ent.graphql"),
        // 配置自动生成的类型
        entgql.WithConfigPath("gqlgen.yml"),
        // 启用Relay规范
        entgql.WithRelaySpec(true),
        // Node接口
        entgql.WithNodeDescriptor(true),
    )
    if err != nil {
        log.Fatalf("creating entgql extension: %v", err)
    }
    
    opts := []entc.Option{
        entc.Extensions(ex),
    }
    
    if err := entc.Generate("./ent/schema", &gen.Config{}, opts...); err != nil {
        log.Fatalf("running ent codegen: %v", err)
    }
}

11.2 Schema添加GraphQL注解

package schema

import (
    "entgo.io/contrib/entgql"
    "entgo.io/ent"
    "entgo.io/ent/schema"
    "entgo.io/ent/schema/edge"
    "entgo.io/ent/schema/field"
)

type User struct {
    ent.Schema
}

func (User) Fields() []ent.Field {
    return []ent.Field{
        field.String("name").
            Annotations(
                entgql.OrderField("NAME"),
            ),
        field.String("email").
            Annotations(
                entgql.OrderField("EMAIL"),
            ),
        field.Time("created_at").
            Annotations(
                entgql.OrderField("CREATED_AT"),
            ),
        field.Enum("status").
            Values("active", "inactive").
            Annotations(
                entgql.Type("UserStatus"),
            ),
    }
}

func (User) Edges() []ent.Edge {
    return []ent.Edge{
        edge.To("posts", Post.Type).
            Annotations(
                entgql.RelayConnection(),
                entgql.OrderField("POSTS_COUNT"),
            ),
    }
}

func (User) Annotations() []schema.Annotation {
    return []schema.Annotation{
        entgql.RelayConnection(),
        entgql.QueryField(),
        entgql.Mutations(
            entgql.MutationCreate(),
            entgql.MutationUpdate(),
        ),
    }
}

11.3 GraphQL Resolver实现

package resolver

import (
    "context"

    "myapp/ent"
    "myapp/graph/generated"
)

type Resolver struct {
    Client *ent.Client
}

func NewSchema(client *ent.Client) *generated.Config {
    return &generated.Config{
        Resolvers: &Resolver{Client: client},
    }
}

// Query resolver
type queryResolver struct{ *Resolver }

func (r *Resolver) Query() generated.QueryResolver {
    return &queryResolver{r}
}

func (r *queryResolver) Node(ctx context.Context, id int) (ent.Noder, error) {
    return r.Client.Noder(ctx, id)
}

func (r *queryResolver) Nodes(ctx context.Context, ids []int) ([]ent.Noder, error) {
    return r.Client.Noders(ctx, ids)
}

func (r *queryResolver) Users(
    ctx context.Context,
    after *ent.Cursor,
    first *int,
    before *ent.Cursor,
    last *int,
    orderBy *ent.UserOrder,
    where *ent.UserWhereInput,
) (*ent.UserConnection, error) {
    return r.Client.User.
        Query().
        Paginate(ctx, after, first, before, last,
            ent.WithUserOrder(orderBy),
            ent.WithUserFilter(where.Filter),
        )
}

// Mutation resolver
type mutationResolver struct{ *Resolver }

func (r *Resolver) Mutation() generated.MutationResolver {
    return &mutationResolver{r}
}

func (r *mutationResolver) CreateUser(ctx context.Context, input ent.CreateUserInput) (*ent.User, error) {
    return r.Client.User.Create().SetInput(input).Save(ctx)
}

func (r *mutationResolver) UpdateUser(ctx context.Context, id int, input ent.UpdateUserInput) (*ent.User, error) {
    return r.Client.User.UpdateOneID(id).SetInput(input).Save(ctx)
}

func (r *mutationResolver) DeleteUser(ctx context.Context, id int) (*ent.User, error) {
    user, err := r.Client.User.Get(ctx, id)
    if err != nil {
        return nil, err
    }
    
    if err := r.Client.User.DeleteOne(user).Exec(ctx); err != nil {
        return nil, err
    }
    
    return user, nil
}

11.4 启动GraphQL服务器

package main

import (
    "log"
    "net/http"

    "github.com/99designs/gqlgen/graphql/handler"
    "github.com/99designs/gqlgen/graphql/playground"
    
    "myapp/ent"
    "myapp/graph/generated"
    "myapp/graph/resolver"
)

func main() {
    client, err := ent.Open("postgres", "...")
    if err != nil {
        log.Fatal(err)
    }
    defer client.Close()
    
    // 创建GraphQL服务器
    srv := handler.NewDefaultServer(
        generated.NewExecutableSchema(resolver.NewSchema(client)),
    )
    
    // 添加中间件
    srv.Use(entgql.Transactioner{TxOpener: client})
    
    http.Handle("/", playground.Handler("GraphQL playground", "/query"))
    http.Handle("/query", srv)
    
    log.Println("GraphQL server running at http://localhost:8080/")
    log.Fatal(http.ListenAndServe(":8080", nil))
}

十二、gRPC集成

12.1 定义Proto文件

// proto/user.proto
syntax = "proto3";

package user;
option go_package = "myapp/proto/user";

service UserService {
    rpc CreateUser(CreateUserRequest) returns (User);
    rpc GetUser(GetUserRequest) returns (User);
    rpc ListUsers(ListUsersRequest) returns (ListUsersResponse);
    rpc UpdateUser(UpdateUserRequest) returns (User);
    rpc DeleteUser(DeleteUserRequest) returns (DeleteUserResponse);
}

message User {
    int64 id = 1;
    string name = 2;
    string email = 3;
    string status = 4;
    int64 created_at = 5;
}

message CreateUserRequest {
    string name = 1;
    string email = 2;
}

message GetUserRequest {
    int64 id = 1;
}

message ListUsersRequest {
    int32 page = 1;
    int32 page_size = 2;
    string status = 3;
}

message ListUsersResponse {
    repeated User users = 1;
    int32 total = 2;
    int32 page = 3;
    int32 pages = 4;
}

message UpdateUserRequest {
    int64 id = 1;
    optional string name = 2;
    optional string email = 3;
    optional string status = 4;
}

message DeleteUserRequest {
    int64 id = 1;
}

message DeleteUserResponse {
    bool success = 1;
}

12.2 实现gRPC服务

package grpc

import (
    "context"
    "time"

    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
    
    "myapp/ent"
    "myapp/ent/user"
    pb "myapp/proto/user"
)

type UserServiceServer struct {
    pb.UnimplementedUserServiceServer
    client *ent.Client
}

func NewUserServiceServer(client *ent.Client) *UserServiceServer {
    return &UserServiceServer{client: client}
}

// 将Ent User转换为Proto User
func toProtoUser(u *ent.User) *pb.User {
    return &pb.User{
        Id:        int64(u.ID),
        Name:      u.Name,
        Email:     u.Email,
        Status:    u.Status.String(),
        CreatedAt: u.CreatedAt.Unix(),
    }
}

// CreateUser 创建用户
func (s *UserServiceServer) CreateUser(ctx context.Context, req *pb.CreateUserRequest) (*pb.User, error) {
    // 参数验证
    if req.Name == "" {
        return nil, status.Error(codes.InvalidArgument, "name is required")
    }
    if req.Email == "" {
        return nil, status.Error(codes.InvalidArgument, "email is required")
    }
    
    // 创建用户
    u, err := s.client.User.
        Create().
        SetName(req.Name).
        SetEmail(req.Email).
        SetStatus(user.StatusActive).
        Save(ctx)
    
    if err != nil {
        if ent.IsConstraintError(err) {
            return nil, status.Error(codes.AlreadyExists, "email already exists")
        }
        return nil, status.Error(codes.Internal, err.Error())
    }
    
    return toProtoUser(u), nil
}

// GetUser 获取用户
func (s *UserServiceServer) GetUser(ctx context.Context, req *pb.GetUserRequest) (*pb.User, error) {
    u, err := s.client.User.Get(ctx, int(req.Id))
    if err != nil {
        if ent.IsNotFound(err) {
            return nil, status.Error(codes.NotFound, "user not found")
        }
        return nil, status.Error(codes.Internal, err.Error())
    }
    
    return toProtoUser(u), nil
}

// ListUsers 列出用户
func (s *UserServiceServer) ListUsers(ctx context.Context, req *pb.ListUsersRequest) (*pb.ListUsersResponse, error) {
    // 设置默认分页参数
    page := int(req.Page)
    if page < 1 {
        page = 1
    }
    pageSize := int(req.PageSize)
    if pageSize < 1 || pageSize > 100 {
        pageSize = 10
    }
    
    // 构建查询
    query := s.client.User.Query()
    
    // 应用过滤条件
    if req.Status != "" {
        query = query.Where(user.StatusEQ(user.Status(req.Status)))
    }
    
    // 获取总数
    total, err := query.Clone().Count(ctx)
    if err != nil {
        return nil, status.Error(codes.Internal, err.Error())
    }
    
    // 分页查询
    users, err := query.
        Order(ent.Desc(user.FieldCreatedAt)).
        Offset((page - 1) * pageSize).
        Limit(pageSize).
        All(ctx)
    
    if err != nil {
        return nil, status.Error(codes.Internal, err.Error())
    }
    
    // 转换结果
    pbUsers := make([]*pb.User, len(users))
    for i, u := range users {
        pbUsers[i] = toProtoUser(u)
    }
    
    return &pb.ListUsersResponse{
        Users: pbUsers,
        Total: int32(total),
        Page:  int32(page),
        Pages: int32((total + pageSize - 1) / pageSize),
    }, nil
}

// UpdateUser 更新用户
func (s *UserServiceServer) UpdateUser(ctx context.Context, req *pb.UpdateUserRequest) (*pb.User, error) {
    // 检查用户是否存在
    exists, err := s.client.User.Query().Where(user.IDEQ(int(req.Id))).Exist(ctx)
    if err != nil {
        return nil, status.Error(codes.Internal, err.Error())
    }
    if !exists {
        return nil, status.Error(codes.NotFound, "user not found")
    }
    
    // 构建更新
    update := s.client.User.UpdateOneID(int(req.Id))
    
    if req.Name != nil {
        update = update.SetName(*req.Name)
    }
    if req.Email != nil {
        update = update.SetEmail(*req.Email)
    }
    if req.Status != nil {
        update = update.SetStatus(user.Status(*req.Status))
    }
    
    u, err := update.Save(ctx)
    if err != nil {
        if ent.IsConstraintError(err) {
            return nil, status.Error(codes.AlreadyExists, "email already exists")
        }
        return nil, status.Error(codes.Internal, err.Error())
    }
    
    return toProtoUser(u), nil
}

// DeleteUser 删除用户
func (s *UserServiceServer) DeleteUser(ctx context.Context, req *pb.DeleteUserRequest) (*pb.DeleteUserResponse, error) {
    err := s.client.User.DeleteOneID(int(req.Id)).Exec(ctx)
    if err != nil {
        if ent.IsNotFound(err) {
            return nil, status.Error(codes.NotFound, "user not found")
        }
        return nil, status.Error(codes.Internal, err.Error())
    }
    
    return &pb.DeleteUserResponse{Success: true}, nil
}

12.3 启动gRPC服务器

package main

import (
    "log"
    "net"

    "google.golang.org/grpc"
    "google.golang.org/grpc/reflection"
    
    "myapp/ent"
    grpcserver "myapp/grpc"
    pb "myapp/proto/user"
)

func main() {
    // 连接数据库
    client, err := ent.Open("postgres", "...")
    if err != nil {
        log.Fatalf("failed opening connection to postgres: %v", err)
    }
    defer client.Close()
    
    // 运行迁移
    if err := client.Schema.Create(context.Background()); err != nil {
        log.Fatalf("failed creating schema: %v", err)
    }
    
    // 创建gRPC服务器
    server := grpc.NewServer(
        grpc.UnaryInterceptor(loggingInterceptor),
    )
    
    // 注册服务
    pb.RegisterUserServiceServer(server, grpcserver.NewUserServiceServer(client))
    
    // 注册反射服务(用于grpcurl等工具)
    reflection.Register(server)
    
    // 启动服务器
    lis, err := net.Listen("tcp", ":50051")
    if err != nil {
        log.Fatalf("failed to listen: %v", err)
    }
    
    log.Println("gRPC server running at :50051")
    if err := server.Serve(lis); err != nil {
        log.Fatalf("failed to serve: %v", err)
    }
}

// 日志拦截器
func loggingInterceptor(
    ctx context.Context,
    req interface{},
    info *grpc.UnaryServerInfo,
    handler grpc.UnaryHandler,
) (interface{}, error) {
    start := time.Now()
    
    resp, err := handler(ctx, req)
    
    log.Printf(
        "method=%s duration=%s error=%v",
        info.FullMethod,
        time.Since(start),
        err,
    )
    
    return resp, err
}

十三、测试

13.1 单元测试

package ent_test

import (
    "context"
    "testing"

    "myapp/ent"
    "myapp/ent/enttest"
    "myapp/ent/user"
    
    _ "github.com/mattn/go-sqlite3"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
)

func TestUserCRUD(t *testing.T) {
    // 创建测试客户端(使用SQLite内存数据库)
    client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
    defer client.Close()
    
    ctx := context.Background()
    
    // 测试创建
    t.Run("Create", func(t *testing.T) {
        u, err := client.User.
            Create().
            SetName("测试用户").
            SetEmail("test@example.com").
            SetAge(25).
            Save(ctx)
        
        require.NoError(t, err)
        assert.NotZero(t, u.ID)
        assert.Equal(t, "测试用户", u.Name)
        assert.Equal(t, "test@example.com", u.Email)
    })
    
    // 测试查询
    t.Run("Query", func(t *testing.T) {
        u, err := client.User.
            Query().
            Where(user.EmailEQ("test@example.com")).
            Only(ctx)
        
        require.NoError(t, err)
        assert.Equal(t, "测试用户", u.Name)
    })
    
    // 测试更新
    t.Run("Update", func(t *testing.T) {
        affected, err := client.User.
            Update().
            Where(user.EmailEQ("test@example.com")).
            SetName("更新后的名字").
            Save(ctx)
        
        require.NoError(t, err)
        assert.Equal(t, 1, affected)
        
        // 验证更新
        u, err := client.User.
            Query().
            Where(user.EmailEQ("test@example.com")).
            Only(ctx)
        
        require.NoError(t, err)
        assert.Equal(t, "更新后的名字", u.Name)
    })
    
    // 测试删除
    t.Run("Delete", func(t *testing.T) {
        affected, err := client.User.
            Delete().
            Where(user.EmailEQ("test@example.com")).
            Exec(ctx)
        
        require.NoError(t, err)
        assert.Equal(t, 1, affected)
        
        // 验证删除
        exists, err := client.User.
            Query().
            Where(user.EmailEQ("test@example.com")).
            Exist(ctx)
        
        require.NoError(t, err)
        assert.False(t, exists)
    })
}

13.2 关系测试

func TestUserPostRelation(t *testing.T) {
    client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
    defer client.Close()
    
    ctx := context.Background()
    
    // 创建用户和文章
    t.Run("CreateWithRelation", func(t *testing.T) {
        // 创建用户
        u, err := client.User.
            Create().
            SetName("作者").
            SetEmail("author@example.com").
            Save(ctx)
        require.NoError(t, err)
        
        // 创建文章
        p1, err := client.Post.
            Create().
            SetTitle("文章1").
            SetContent("内容1").
            SetAuthor(u).
            Save(ctx)
        require.NoError(t, err)
        
        p2, err := client.Post.
            Create().
            SetTitle("文章2").
            SetContent("内容2").
            SetAuthorID(u.ID).
            Save(ctx)
        require.NoError(t, err)
        
        // 验证关系
        posts, err := u.QueryPosts().All(ctx)
        require.NoError(t, err)
        assert.Len(t, posts, 2)
        
        // 验证反向关系
        author, err := p1.QueryAuthor().Only(ctx)
        require.NoError(t, err)
        assert.Equal(t, u.ID, author.ID)
    })
    
    // 测试预加载
    t.Run("EagerLoading", func(t *testing.T) {
        users, err := client.User.
            Query().
            WithPosts().
            All(ctx)
        
        require.NoError(t, err)
        assert.NotEmpty(t, users)
        
        for _, u := range users {
            // Edges.Posts已经被预加载
            assert.NotNil(t, u.Edges.Posts)
        }
    })
}

13.3 事务测试

func TestTransaction(t *testing.T) {
    client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
    defer client.Close()
    
    ctx := context.Background()
    
    t.Run("CommitTransaction", func(t *testing.T) {
        tx, err := client.Tx(ctx)
        require.NoError(t, err)
        
        _, err = tx.User.
            Create().
            SetName("事务用户").
            SetEmail("tx@example.com").
            Save(ctx)
        require.NoError(t, err)
        
        err = tx.Commit()
        require.NoError(t, err)
        
        // 验证提交成功
        exists, err := client.User.
            Query().
            Where(user.EmailEQ("tx@example.com")).
            Exist(ctx)
        require.NoError(t, err)
        assert.True(t, exists)
    })
    
    t.Run("RollbackTransaction", func(t *testing.T) {
        tx, err := client.Tx(ctx)
        require.NoError(t, err)
        
        _, err = tx.User.
            Create().
            SetName("回滚用户").
            SetEmail("rollback@example.com").
            Save(ctx)
        require.NoError(t, err)

        err = tx.Rollback()
        require.NoError(t, err)
        
        // 验证回滚成功
        exists, err := client.User.
            Query().
            Where(user.EmailEQ("rollback@example.com")).
            Exist(ctx)
        require.NoError(t, err)
        assert.False(t, exists)
    })
}

13.4 Mock测试

package service_test

import (
    "context"
    "testing"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/mock"
    
    "myapp/ent"
)

// MockUserRepository 用户仓库Mock
type MockUserRepository struct {
    mock.Mock
}

func (m *MockUserRepository) Create(ctx context.Context, name, email string) (*ent.User, error) {
    args := m.Called(ctx, name, email)
    if args.Get(0) == nil {
        return nil, args.Error(1)
    }
    return args.Get(0).(*ent.User), args.Error(1)
}

func (m *MockUserRepository) GetByID(ctx context.Context, id int) (*ent.User, error) {
    args := m.Called(ctx, id)
    if args.Get(0) == nil {
        return nil, args.Error(1)
    }
    return args.Get(0).(*ent.User), args.Error(1)
}

func (m *MockUserRepository) GetByEmail(ctx context.Context, email string) (*ent.User, error) {
    args := m.Called(ctx, email)
    if args.Get(0) == nil {
        return nil, args.Error(1)
    }
    return args.Get(0).(*ent.User), args.Error(1)
}

// UserService 用户服务
type UserService struct {
    repo UserRepository
}

type UserRepository interface {
    Create(ctx context.Context, name, email string) (*ent.User, error)
    GetByID(ctx context.Context, id int) (*ent.User, error)
    GetByEmail(ctx context.Context, email string) (*ent.User, error)
}

func NewUserService(repo UserRepository) *UserService {
    return &UserService{repo: repo}
}

func (s *UserService) Register(ctx context.Context, name, email string) (*ent.User, error) {
    // 检查邮箱是否已存在
    existing, err := s.repo.GetByEmail(ctx, email)
    if err == nil && existing != nil {
        return nil, fmt.Errorf("email already exists")
    }
    
    return s.repo.Create(ctx, name, email)
}

// 测试
func TestUserService_Register(t *testing.T) {
    ctx := context.Background()
    
    t.Run("Success", func(t *testing.T) {
        mockRepo := new(MockUserRepository)
        service := NewUserService(mockRepo)
        
        expectedUser := &ent.User{
            ID:    1,
            Name:  "测试用户",
            Email: "test@example.com",
        }
        
        // 设置期望
        mockRepo.On("GetByEmail", ctx, "test@example.com").Return(nil, fmt.Errorf("not found"))
        mockRepo.On("Create", ctx, "测试用户", "test@example.com").Return(expectedUser, nil)
        
        // 执行
        user, err := service.Register(ctx, "测试用户", "test@example.com")
        
        // 断言
        assert.NoError(t, err)
        assert.Equal(t, expectedUser, user)
        mockRepo.AssertExpectations(t)
    })
    
    t.Run("EmailExists", func(t *testing.T) {
        mockRepo := new(MockUserRepository)
        service := NewUserService(mockRepo)
        
        existingUser := &ent.User{
            ID:    1,
            Email: "existing@example.com",
        }
        
        mockRepo.On("GetByEmail", ctx, "existing@example.com").Return(existingUser, nil)
        
        user, err := service.Register(ctx, "新用户", "existing@example.com")
        
        assert.Error(t, err)
        assert.Nil(t, user)
        assert.Contains(t, err.Error(), "email already exists")
        mockRepo.AssertExpectations(t)
    })
}

13.5 Hooks测试

func TestHooks(t *testing.T) {
    client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
    defer client.Close()
    
    ctx := context.Background()
    
    // 记录Hook被调用
    var hookCalled bool
    
    // 添加测试Hook
    client.User.Use(func(next ent.Mutator) ent.Mutator {
        return hook.UserFunc(func(ctx context.Context, m *ent.UserMutation) (ent.Value, error) {
            hookCalled = true
            return next.Mutate(ctx, m)
        })
    })
    
    t.Run("HookIsCalled", func(t *testing.T) {
        hookCalled = false
        
        _, err := client.User.
            Create().
            SetName("Hook测试").
            SetEmail("hook@example.com").
            Save(ctx)
        
        require.NoError(t, err)
        assert.True(t, hookCalled, "Hook should be called")
    })
}

十四、性能优化

14.1 批量操作优化

package optimization

import (
    "context"

    "myapp/ent"
)

// 批量创建(使用CreateBulk)
func BulkCreate(ctx context.Context, client *ent.Client, users []UserInput) error {
    builders := make([]*ent.UserCreate, len(users))
    for i, u := range users {
        builders[i] = client.User.
            Create().
            SetName(u.Name).
            SetEmail(u.Email)
    }
    
    _, err := client.User.CreateBulk(builders...).Save(ctx)
    return err
}

// 批量更新
func BulkUpdate(ctx context.Context, client *ent.Client, ids []int, status string) error {
    _, err := client.User.
        Update().
        Where(user.IDIn(ids...)).
        SetStatus(status).
        Save(ctx)
    return err
}

// 批量删除
func BulkDelete(ctx context.Context, client *ent.Client, ids []int) error {
    _, err := client.User.
        Delete().
        Where(user.IDIn(ids...)).
        Exec(ctx)
    return err
}

// 分批处理大量数据
func ProcessInBatches(ctx context.Context, client *ent.Client, batchSize int, fn func([]*ent.User) error) error {
    var lastID int
    
    for {
        users, err := client.User.
            Query().
            Where(user.IDGT(lastID)).
            Order(ent.Asc(user.FieldID)).
            Limit(batchSize).
            All(ctx)
        
        if err != nil {
            return err
        }
        
        if len(users) == 0 {
            break
        }
        
        if err := fn(users); err != nil {
            return err
        }
        
        lastID = users[len(users)-1].ID
    }
    
    return nil
}

14.2 查询优化

package optimization

import (
    "context"

    "myapp/ent"
    "myapp/ent/user"
    "myapp/ent/post"
)

// 只选择需要的字段
func SelectSpecificFields(ctx context.Context, client *ent.Client) ([]struct {
    ID    int
    Name  string
    Email string
}, error) {
    var results []struct {
        ID    int    `json:"id"`
        Name  string `json:"name"`
        Email string `json:"email"`
    }
    
    err := client.User.
        Query().
        Select(user.FieldID, user.FieldName, user.FieldEmail).
        Scan(ctx, &results)
    
    return results, err
}

// 避免N+1问题 - 使用预加载
func EagerLoadingExample(ctx context.Context, client *ent.Client) ([]*ent.User, error) {
    return client.User.
        Query().
        WithPosts(func(q *ent.PostQuery) {
            q.Where(post.StatusEQ("published"))
            q.Order(ent.Desc(post.FieldCreatedAt))
            q.Limit(5)
            q.WithTags() // 嵌套预加载
        }).
        WithProfile().
        All(ctx)
}

// 使用索引优化查询
func QueryWithIndex(ctx context.Context, client *ent.Client, email string) (*ent.User, error) {
    // 确保email字段有索引
    return client.User.
        Query().
        Where(user.EmailEQ(email)).
        Only(ctx)
}

// 只获取计数,不加载实体
func CountOnly(ctx context.Context, client *ent.Client) (int, error) {
    return client.User.
        Query().
        Where(user.StatusEQ("active")).
        Count(ctx)
}

// 只检查存在性
func ExistsOnly(ctx context.Context, client *ent.Client, email string) (bool, error) {
    return client.User.
        Query().
        Where(user.EmailEQ(email)).
        Exist(ctx)
}

// 使用Only代替First(当确定只有一个结果时)
func GetSingleResult(ctx context.Context, client *ent.Client, id int) (*ent.User, error) {
    return client.User.
        Query().
        Where(user.IDEQ(id)).
        Only(ctx) // 如果有多个结果会返回错误
}

14.3 连接池配置

package main

import (
    "database/sql"
    "time"

    "entgo.io/ent/dialect"
    entsql "entgo.io/ent/dialect/sql"
    
    "myapp/ent"
    
    _ "github.com/lib/pq"
)

func NewClientWithPool() (*ent.Client, error) {
    // 打开数据库连接
    db, err := sql.Open("postgres", "postgres://localhost:5432/myapp?sslmode=disable")
    if err != nil {
        return nil, err
    }
    
    // 配置连接池
    db.SetMaxOpenConns(100)           // 最大打开连接数
    db.SetMaxIdleConns(10)            // 最大空闲连接数
    db.SetConnMaxLifetime(time.Hour)  // 连接最大存活时间
    db.SetConnMaxIdleTime(time.Minute * 30) // 空闲连接最大存活时间
    
    // 创建Ent驱动
    drv := entsql.OpenDB(dialect.Postgres, db)
    
    // 创建Ent客户端
    client := ent.NewClient(ent.Driver(drv))
    
    return client, nil
}

14.4 缓存策略

package cache

import (
    "context"
    "encoding/json"
    "fmt"
    "time"

    "github.com/go-redis/redis/v8"
    
    "myapp/ent"
)

type CachedUserRepository struct {
    client *ent.Client
    redis  *redis.Client
    ttl    time.Duration
}

func NewCachedUserRepository(client *ent.Client, redis *redis.Client) *CachedUserRepository {
    return &CachedUserRepository{
        client: client,
        redis:  redis,
        ttl:    time.Minute * 15,
    }
}

func (r *CachedUserRepository) cacheKey(id int) string {
    return fmt.Sprintf("user:%d", id)
}

// GetByID 带缓存的获取
func (r *CachedUserRepository) GetByID(ctx context.Context, id int) (*ent.User, error) {
    key := r.cacheKey(id)
    
    // 尝试从缓存获取
    cached, err := r.redis.Get(ctx, key).Result()
    if err == nil {
        var user ent.User
        if err := json.Unmarshal([]byte(cached), &user); err == nil {
            return &user, nil
        }
    }
    
    // 从数据库获取
    user, err := r.client.User.Get(ctx, id)
    if err != nil {
        return nil, err
    }
    
    // 写入缓存
    data, _ := json.Marshal(user)
    r.redis.Set(ctx, key, data, r.ttl)
    
    return user, nil
}

// Update 更新并清除缓存
func (r *CachedUserRepository) Update(ctx context.Context, id int, name string) (*ent.User, error) {
    user, err := r.client.User.
        UpdateOneID(id).
        SetName(name).
        Save(ctx)
    
    if err != nil {
        return nil, err
    }
    
    // 清除缓存
    r.redis.Del(ctx, r.cacheKey(id))
    
    return user, nil
}

// Delete 删除并清除缓存
func (r *CachedUserRepository) Delete(ctx context.Context, id int) error {
    err := r.client.User.DeleteOneID(id).Exec(ctx)
    if err != nil {
        return err
    }
    
    // 清除缓存
    r.redis.Del(ctx, r.cacheKey(id))
    
    return nil
}

// 使用缓存拦截器
func CacheInterceptor(redis *redis.Client, ttl time.Duration) ent.Interceptor {
    return intercept.TraverseFunc(func(ctx context.Context, q intercept.Query) error {
        // 可以在这里实现查询级别的缓存
        return nil
    })
}

十五、多数据库支持

15.1 读写分离

package database

import (
    "context"
    "database/sql"
    "sync/atomic"

    "entgo.io/ent/dialect"
    entsql "entgo.io/ent/dialect/sql"
    
    "myapp/ent"
)

// ReadWriteDriver 读写分离驱动
type ReadWriteDriver struct {
    write   dialect.Driver
    reads   []dialect.Driver
    counter uint64
}

func NewReadWriteDriver(writeDB *sql.DB, readDBs ...*sql.DB) *ReadWriteDriver {
    reads := make([]dialect.Driver, len(readDBs))
    for i, db := range readDBs {
        reads[i] = entsql.OpenDB(dialect.Postgres, db)
    }
    
    return &ReadWriteDriver{
        write: entsql.OpenDB(dialect.Postgres, writeDB),
        reads: reads,




    }
}

// 实现dialect.Driver接口
func (d *ReadWriteDriver) Dialect() string {
    return d.write.Dialect()
}

func (d *ReadWriteDriver) Close() error {
    for _, read := range d.reads {
        read.Close()
    }
    return d.write.Close()
}

// Exec 写操作使用主库
func (d *ReadWriteDriver) Exec(ctx context.Context, query string, args, v interface{}) error {
    return d.write.Exec(ctx, query, args, v)
}

// Query 读操作使用从库(轮询)
func (d *ReadWriteDriver) Query(ctx context.Context, query string, args, v interface{}) error {
    // 检查是否强制使用主库
    if useMaster(ctx) {
        return d.write.Query(ctx, query, args, v)
    }
    
    // 轮询选择从库
    idx := atomic.AddUint64(&d.counter, 1) % uint64(len(d.reads))
    return d.reads[idx].Query(ctx, query, args, v)
}

// Tx 事务使用主库
func (d *ReadWriteDriver) Tx(ctx context.Context) (dialect.Tx, error) {
    return d.write.Tx(ctx)
}

// BeginTx 开始事务
func (d *ReadWriteDriver) BeginTx(ctx context.Context, opts *sql.TxOptions) (dialect.Tx, error) {
    return d.write.(interface {
        BeginTx(context.Context, *sql.TxOptions) (dialect.Tx, error)
    }).BeginTx(ctx, opts)
}

// Context key for forcing master
type masterKey struct{}

func useMaster(ctx context.Context) bool {
    v, _ := ctx.Value(masterKey{}).(bool)
    return v
}

// WithMaster 强制使用主库
func WithMaster(ctx context.Context) context.Context {
    return context.WithValue(ctx, masterKey{}, true)
}

// 使用示例
func SetupReadWriteClient() (*ent.Client, error) {
    // 主库
    masterDB, err := sql.Open("postgres", "postgres://master:5432/myapp")
    if err != nil {
        return nil, err
    }
    
    // 从库
    slave1DB, err := sql.Open("postgres", "postgres://slave1:5432/myapp")
    if err != nil {
        return nil, err
    }
    
    slave2DB, err := sql.Open("postgres", "postgres://slave2:5432/myapp")
    if err != nil {
        return nil, err
    }
    
    // 创建读写分离驱动
    drv := NewReadWriteDriver(masterDB, slave1DB, slave2DB)
    
    return ent.NewClient(ent.Driver(drv)), nil
}

// 业务代码使用
func QueryExample(ctx context.Context, client *ent.Client) {
    // 普通查询 - 使用从库
    users, _ := client.User.Query().All(ctx)
    
    // 强制使用主库查询(例如刚写入后立即读取)
    user, _ := client.User.Query().
        Where(user.IDEQ(1)).
        Only(WithMaster(ctx))
}

15.2 多租户数据库

package multitenancy

import (
    "context"
    "fmt"
    "sync"

    "entgo.io/ent/dialect"
    entsql "entgo.io/ent/dialect/sql"
    
    "myapp/ent"
)

// TenantManager 租户管理器
type TenantManager struct {
    clients map[string]*ent.Client
    mu      sync.RWMutex
    config  DatabaseConfig
}

type DatabaseConfig struct {
    Host     string
    Port     int
    User     string
    Password string
}

func NewTenantManager(config DatabaseConfig) *TenantManager {
    return &TenantManager{
        clients: make(map[string]*ent.Client),
        config:  config,
    }
}

// GetClient 获取租户的数据库客户端
func (m *TenantManager) GetClient(tenantID string) (*ent.Client, error) {
    m.mu.RLock()
    client, exists := m.clients[tenantID]
    m.mu.RUnlock()
    
    if exists {
        return client, nil
    }
    
    // 创建新连接
    m.mu.Lock()
    defer m.mu.Unlock()
    
    // 双重检查
    if client, exists := m.clients[tenantID]; exists {
        return client, nil
    }
    
    // 每个租户使用独立的数据库
    dsn := fmt.Sprintf(
        "postgres://%s:%s@%s:%d/tenant_%s?sslmode=disable",
        m.config.User,
        m.config.Password,
        m.config.Host,
        m.config.Port,
        tenantID,
    )
    
    client, err := ent.Open("postgres", dsn)
    if err != nil {
        return nil, err
    }
    
    // 自动迁移
    if err := client.Schema.Create(context.Background()); err != nil {
        client.Close()
        return nil, err
    }
    
    m.clients[tenantID] = client
    return client, nil
}

// CloseAll 关闭所有连接
func (m *TenantManager) CloseAll() {
    m.mu.Lock()
    defer m.mu.Unlock()
    
    for _, client := range m.clients {
        client.Close()
    }
}

// 中间件:从请求中获取租户并注入客户端
func TenantMiddleware(manager *TenantManager) func(http.Handler) http.Handler {
    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            tenantID := r.Header.Get("X-Tenant-ID")
            if tenantID == "" {
                http.Error(w, "Tenant ID required", http.StatusBadRequest)
                return
            }
            
            client, err := manager.GetClient(tenantID)
            if err != nil {
                http.Error(w, "Failed to get tenant database", http.StatusInternalServerError)
                return
            }
            
            ctx := context.WithValue(r.Context(), "ent_client", client)
            next.ServeHTTP(w, r.WithContext(ctx))
        })
    }
}

// 从context获取客户端
func ClientFromContext(ctx context.Context) *ent.Client {
    return ctx.Value("ent_client").(*ent.Client)
}

15.3 Schema级别多租户

package schema

import (
    "context"
    "fmt"

    "entgo.io/ent/dialect"
    "entgo.io/ent/dialect/sql"
    
    "myapp/ent"
)

// SchemaDriver 支持Schema切换的驱动
type SchemaDriver struct {
    dialect.Driver
    schema string
}

func NewSchemaDriver(drv dialect.Driver, schema string) *SchemaDriver {
    return &SchemaDriver{
        Driver: drv,
        schema: schema,
    }
}

func (d *SchemaDriver) Exec(ctx context.Context, query string, args, v interface{}) error {
    // 设置search_path
    setSchema := fmt.Sprintf("SET search_path TO %s", d.schema)
    if err := d.Driver.Exec(ctx, setSchema, []interface{}{}, nil); err != nil {
        return err
    }
    return d.Driver.Exec(ctx, query, args, v)
}

func (d *SchemaDriver) Query(ctx context.Context, query string, args, v interface{}) error {
    setSchema := fmt.Sprintf("SET search_path TO %s", d.schema)
    if err := d.Driver.Exec(ctx, setSchema, []interface{}{}, nil); err != nil {
        return err
    }
    return d.Driver.Query(ctx, query, args, v)
}

// 为租户创建Schema
func CreateTenantSchema(ctx context.Context, db *sql.DB, tenantID string) error {
    schemaName := fmt.Sprintf("tenant_%s", tenantID)
    _, err := db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schemaName))
    return err
}

// 获取租户客户端
func GetTenantClient(baseClient *ent.Client, tenantID string) *ent.Client {
    schemaName := fmt.Sprintf("tenant_%s", tenantID)
    drv := NewSchemaDriver(baseClient.Driver(), schemaName)
    return ent.NewClient(ent.Driver(drv))
}

十六、完整项目结构

myapp/
├── cmd/
│   └── server/
│       └── main.go              # 应用入口
├── ent/
│   ├── schema/                  # Schema定义
│   │   ├── user.go
│   │   ├── post.go
│   │   ├── comment.go
│   │   └── mixin/
│   │       ├── time.go
│   │       └── soft_delete.go
│   ├── generate.go              # 代码生成配置
│   ├── entc.go                  # Ent配置
│   ├── client.go                # 生成的客户端
│   ├── user.go                  # 生成的User实体
│   ├── post.go                  # 生成的Post实体
│   ├── migrate/
│   │   ├── main.go              # 迁移脚本
│   │   └── migrations/          # 版本化迁移文件
│   ├── hook/                    # 生成的Hooks
│   ├── intercept/               # 生成的拦截器
│   └── privacy/                 # 生成的隐私规则
├── internal/
│   ├── config/
│   │   └── config.go            # 配置管理
│   ├── database/
│   │   └── database.go          # 数据库连接
│   ├── repository/              # 数据访问层
│   │   ├── user.go
│   │   └── post.go
│   ├── service/                 # 业务逻辑层
│   │   ├── user.go
│   │   └── post.go
│   └── handler/                 # HTTP处理器
│       ├── user.go
│       └── post.go
├── pkg/
│   ├── pagination/
│   │   └── pagination.go        # 分页工具
│   └── validator/
│       └── validator.go         # 验证工具
├── api/
│   └── proto/                   # gRPC Proto文件
│       └── user.proto
├── graph/                       # GraphQL
│   ├── resolver/
│   │   └── resolver.go
│   └── generated/
├── go.mod
├── go.sum
└── Makefile

16.1 Makefile示例

.PHONY: generate migrate test run

# 生成Ent代码
generate:
	go generate ./ent

# 创建新的Schema
new-schema:
	go run -mod=mod entgo.io/ent/cmd/ent new $(name)

# 生成迁移文件
migrate-diff:
	go run -mod=mod ent/migrate/main.go $(name)

# 应用迁移
migrate-apply:
	atlas migrate apply \
		--dir "file://ent/migrate/migrations" \
		--url "$(DATABASE_URL)"

# 运行测试
test:
	go test -v ./...

# 运行测试(带覆盖率)
test-coverage:
	go test -v -coverprofile=coverage.out ./...
	go tool cover -html=coverage.out

# 运行应用
run:
	go run cmd/server/main.go

# 格式化代码
fmt:
	go fmt ./...
	goimports -w .

# 代码检查
lint:
	golangci-lint run

# 构建
build:
	go build -o bin/server cmd/server/main.go

# 生成Proto
proto:
	protoc --go_out=. --go-grpc_out=. api/proto/*.proto

16.2 完整的Main函数示例

// cmd/server/main.go
package main

import (
    "context"
    "log"
    "net/http"
    "os"
    "os/signal"
    "syscall"
    "time"

    "github.com/go-chi/chi/v5"
    "github.com/go-chi/chi/v5/middleware"
    
    "myapp/ent"
    "myapp/internal/config"
    "myapp/internal/database"
    "myapp/internal/handler"
    "myapp/internal/repository"
    "myapp/internal/service"
)

func main() {
    // 加载配置
    cfg, err := config.Load()
    if err != nil {
        log.Fatalf("Failed to load config: %v", err)
    }
    
    // 连接数据库
    client, err := database.NewClient(cfg.Database)
    if err != nil {
        log.Fatalf("Failed to connect database: %v", err)
    }
    defer client.Close()
    
    // 运行迁移
    if err := client.Schema.Create(context.Background()); err != nil {
        log.Fatalf("Failed to run migrations: %v", err)
    }
    
    // 初始化层
    userRepo := repository.NewUserRepository(client)
    userService := service.NewUserService(userRepo)
    userHandler := handler.NewUserHandler(userService)
    
    // 设置路由
    r := chi.NewRouter()
    
    // 中间件
    r.Use(middleware.Logger)
    r.Use(middleware.Recoverer)
    r.Use(middleware.Timeout(60 * time.Second))
    
    // 路由
    r.Route("/api/v1", func(r chi.Router) {
        r.Route("/users", func(r chi.Router) {
            r.Get("/", userHandler.List)
            r.Post("/", userHandler.Create)
            r.Get("/{id}", userHandler.Get)
            r.Put("/{id}", userHandler.Update)
            r.Delete("/{id}", userHandler.Delete)
        })
    })
    
    // 健康检查
    r.Get("/health", func(w http.ResponseWriter, r *http.Request) {
        w.WriteHeader(http.StatusOK)
        w.Write([]byte("OK"))
    })
    
    // 创建服务器
    server := &http.Server{
        Addr:         ":" + cfg.Server.Port,
        Handler:      r,
        ReadTimeout:  15 * time.Second,
        WriteTimeout: 15 * time.Second,
        IdleTimeout:  60 * time.Second,
    }
    
    // 优雅关闭
    go func() {
        sigChan := make(chan os.Signal, 1)
        signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
        <-sigChan

        log.Println("Shutting down server...")

        ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
        defer cancel()

        if err := server.Shutdown(ctx); err != nil {
            log.Printf("Server shutdown error: %v", err)
        }

        client.Close()
        log.Println("Server stopped")
    }()

    // 启动服务器
    log.Printf("Server running at %s", server.Addr)
    if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
        log.Fatalf("Server error: %v", err)
    }
}

十七、最佳实践

17.1 Schema设计最佳实践

使用Mixin复用通用字段

// 创建基础Mixin
type BaseMixin struct {
    mixin.Schema
}

func (BaseMixin) Fields() []ent.Field {
    return []ent.Field{
        field.Time("created_at").
            Immutable().
            Default(time.Now),
        field.Time("updated_at").
            Default(time.Now).
            UpdateDefault(time.Now),
    }
}

// 在所有实体中使用
type User struct {
    ent.Schema
}

func (User) Mixin() []ent.Mixin {
    return []ent.Mixin{BaseMixin{}}
}

合理使用索引

// 为经常查询的字段添加索引
func (User) Indexes() []ent.Index {
    return []ent.Index{
        index.Fields("email").Unique(),
        index.Fields("status", "created_at"),
        index.Fields("tenant_id", "email").Unique(),
    }
}

字段验证

func (User) Fields() []ent.Field {
    return []ent.Field{
        // 使用内置验证
        field.String("email").
            NotEmpty().
            Match(regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)),

        // 自定义验证
        field.Int("age").
            Validate(func(age int) error {
                if age < 0 || age > 150 {
                    return errors.New("age must be between 0 and 150")
                }
                return nil
            }),
    }
}

17.2 查询优化最佳实践

避免N+1问题

// ❌ 不好的做法
users, _ := client.User.Query().All(ctx)
for _, u := range users {
    posts, _ := u.QueryPosts().All(ctx) // N+1查询
}

// ✅ 好的做法
users, _ := client.User.Query().
    WithPosts().
    All(ctx)
for _, u := range users {
    posts := u.Edges.Posts // 已预加载
}

只查询需要的字段

// ❌ 不好的做法 - 查询所有字段
users, _ := client.User.Query().All(ctx)

// ✅ 好的做法 - 只查询需要的字段
var results []struct {
    ID   int
    Name string
}
client.User.Query().
    Select(user.FieldID, user.FieldName).
    Scan(ctx, &results)

使用分页

// 始终对大量数据使用分页
users, _ := client.User.Query().
    Offset((page - 1) * pageSize).
    Limit(pageSize).
    All(ctx)

17.3 事务使用最佳实践

保持事务简短

// ✅ 好的做法 - 事务只包含必要的数据库操作
err := WithTx(ctx, client, func(tx *ent.Tx) error {
    user, err := tx.User.Create().SetName("用户").Save(ctx)
    if err != nil {
        return err
    }

    _, err = tx.Profile.Create().SetUserID(user.ID).Save(ctx)
    return err
})

// ❌ 不好的做法 - 事务包含非数据库操作
err := WithTx(ctx, client, func(tx *ent.Tx) error {
    user, err := tx.User.Create().SetName("用户").Save(ctx)
    if err != nil {
        return err
    }

    // 不应该在事务中执行外部API调用
    sendEmail(user.Email) // ❌

    return nil
})

17.4 错误处理最佳实践

// 使用Ent提供的错误检查函数
user, err := client.User.Get(ctx, id)
if err != nil {
    if ent.IsNotFound(err) {
        // 处理未找到错误
        return nil, fmt.Errorf("user not found")
    }
    if ent.IsConstraintError(err) {
        // 处理约束错误(如唯一索引冲突)
        return nil, fmt.Errorf("user already exists")
    }
    // 其他错误
    return nil, err
}

17.5 性能监控

// 添加查询性能监控Hook
client.Use(func(next ent.Mutator) ent.Mutator {
    return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
        start := time.Now()
        v, err := next.Mutate(ctx, m)
        duration := time.Since(start)

        // 记录慢查询
        if duration > 100*time.Millisecond {
            log.Printf("Slow query detected: %s took %v", m.Type(), duration)
        }

        return v, err
    })
})

十八、常见问题

18.1 如何处理软删除?

参见第 10.5 节的软删除实现。

18.2 如何实现乐观锁?

参见第 10.6 节的乐观锁实现。

18.3 如何实现多租户?

参见第 15.2 节的多租户数据库实现。

18.4 如何处理大量数据的迁移?

// 分批处理
func MigrateLargeData(ctx context.Context, client *ent.Client) error {
    batchSize := 1000
    var lastID int

    for {
        users, err := client.User.Query().
            Where(user.IDGT(lastID)).
            Order(ent.Asc(user.FieldID)).
            Limit(batchSize).
            All(ctx)

        if err != nil {
            return err
        }

        if len(users) == 0 {
            break
        }

        // 处理批次数据
        for _, u := range users {
            // 执行迁移逻辑
        }

        lastID = users[len(users)-1].ID
    }

    return nil
}

18.5 如何调试生成的SQL?

// 启用SQL日志
import (
    "entgo.io/ent/dialect/sql"
    entsql "entgo.io/ent/dialect/sql"
)

// 创建带日志的驱动
drv := entsql.OpenDB(dialect.Postgres, db)
client := ent.NewClient(ent.Driver(entsql.Driver(drv)))
client.Debug() // 启用调试模式,会打印所有SQL查询

18.6 如何处理JSON字段?

type Metadata struct {
    Tags    []string          `json:"tags"`
    Settings map[string]string `json:"settings"`
}

// Schema定义
func (User) Fields() []ent.Field {
    return []ent.Field{
        field.JSON("metadata", &Metadata{}).
            Optional(),
    }
}

// 使用
user, _ := client.User.Create().
    SetMetadata(&Metadata{
        Tags: []string{"tag1", "tag2"},
        Settings: map[string]string{"theme": "dark"},
    }).
    Save(ctx)

// 查询
metadata := user.Metadata
tags := metadata.Tags

十九、总结

Ent ORM 是一个功能强大、类型安全的 Go 语言 ORM 框架,具有以下特点:

核心优势

  • 类型安全:100% 静态类型和显式 API
  • 代码生成:通过代码生成避免运行时反射,性能更好
  • 图遍历:轻松处理复杂的关系查询
  • 扩展性强:支持 Hooks、Privacy、Interceptors 等多种扩展机制
  • 现代化:支持 GraphQL、gRPC 等现代开发技术

适用场景

  • 大型应用和微服务
  • 需要复杂数据模型和关系的项目
  • 对类型安全有严格要求的项目
  • 需要代码优先(Code First)方法的项目

学习路径建议

  1. 从基础的 Schema 定义和 CRUD 操作开始
  2. 掌握关系定义和查询
  3. 学习 Hooks 和 Interceptors 进行扩展
  4. 了解高级特性如事务、Privacy、乐观锁等
  5. 根据项目需求集成 GraphQL 或 gRPC
  6. 学习性能优化和最佳实践

参考资源

希望本指南能帮助你快速掌握 Ent ORM 的使用!