本文由claude code生成。作为备忘录和分享,欢迎指正和补充。
一、Ent简介
Ent是一个简单但功能强大的实体框架,用于Go语言,使构建和维护具有大型数据模型的应用程序变得容易。 Ent是由Meta(Facebook)开源团队构建的ORM框架,提供了一个API,用于将任何数据库模式建模为Go对象。
核心特性
Ent的主要特性包括:
- Schema As Code:将任何数据库模式建模为Go对象
- 轻松遍历任何图结构:运行查询、聚合并轻松遍历任何图结构
- 静态类型和显式API:100%静态类型和显式API,通过代码生成实现
- 多存储驱动支持:支持MySQL、PostgreSQL、SQLite和Gremlin
- 易于扩展:使用Go模板简单扩展和自定义
Ent是一个相当新的ORM,使用代码优先的方法,在Go代码中定义模式。Ent之所以流行,是因为它能够优雅地处理复杂的数据模型和关系。
二、快速开始
2.1 安装Ent
# 初始化Go模块
go mod init myapp
# 安装ent CLI工具
go install entgo.io/ent/cmd/ent@latest
# 或者使用go get
go get entgo.io/ent/cmd/ent
2.2 创建第一个Schema
# 创建User实体
ent new User
这会在 ent/schema/ 目录下生成一个基础的schema文件:
package schema
import (
"entgo.io/ent"
"entgo.io/ent/schema/field"
)
// User holds the schema definition for the User entity.
type User struct {
ent.Schema
}
// Fields of the User.
func (User) Fields() []ent.Field {
return []ent.Field{
field.String("name").
NotEmpty(),
field.Int("age").
Positive(),
field.String("email").
Unique(),
field.Time("created_at").
Default(time.Now),
}
}
// Edges of the User.
func (User) Edges() []ent.Edge {
return nil
}
2.3 生成代码
go generate ./ent
2.4 连接数据库并使用
package main
import (
"context"
"log"
"myapp/ent"
_ "github.com/lib/pq" // PostgreSQL
// _ "github.com/go-sql-driver/mysql" // MySQL
// _ "github.com/mattn/go-sqlite3" // SQLite
)
func main() {
// 连接PostgreSQL
client, err := ent.Open("postgres",
"host=localhost port=5432 user=postgres dbname=myapp password=password sslmode=disable")
if err != nil {
log.Fatalf("failed opening connection to postgres: %v", err)
}
defer client.Close()
// 运行自动迁移
if err := client.Schema.Create(context.Background()); err != nil {
log.Fatalf("failed creating schema resources: %v", err)
}
ctx := context.Background()
// 创建用户
user, err := client.User.
Create().
SetName("张三").
SetAge(30).
SetEmail("zhangsan@example.com").
Save(ctx)
if err != nil {
log.Fatalf("failed creating user: %v", err)
}
log.Println("user was created: ", user)
}
2.5 使用SQLite3(非CGO驱动)
📋 使用场景
为什么使用非CGO的SQLite驱动?
- ✅ 跨平台编译 - 不需要CGO,可以轻松交叉编译
- ✅ 部署简单 - 不依赖SQLite C库
- ✅ 开发便捷 - 不需要配置CGO环境
- ✅ 完全兼容 - 功能与CGO版本完全一致
适用场景:
- 嵌入式应用或桌面应用
- 开发和测试环境
- 轻量级数据存储
- 需要跨平台编译的项目
安装和配置
# 安装非CGO的SQLite驱动
go get github.com/ncruces/go-sqlite3
基本使用
package main
import (
"context"
"log"
"myapp/ent"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
)
func main() {
// 使用文件模式连接SQLite
// cache=shared: 允许多个连接共享缓存
// _pragma: 设置SQLite参数
dsn := "file:./data.db?cache=shared" +
"&_pragma=foreign_keys(1)" + // 启用外键约束
"&_pragma=journal_mode(WAL)" + // 使用WAL模式提升并发性能
"&_pragma=synchronous(NORMAL)" + // 平衡性能和安全性
"&_pragma=busy_timeout(10000)" + // 10秒超时
"&_pragma=cache_size(-64000)" // 64MB缓存
client, err := ent.Open("sqlite3", dsn)
if err != nil {
log.Fatalf("failed opening connection to sqlite: %v", err)
}
defer client.Close()
// 运行自动迁移
if err := client.Schema.Create(context.Background()); err != nil {
log.Fatalf("failed creating schema resources: %v", err)
}
ctx := context.Background()
// 创建用户
user, err := client.User.
Create().
SetName("张三").
SetAge(30).
SetEmail("zhangsan@example.com").
Save(ctx)
if err != nil {
log.Fatalf("failed creating user: %v", err)
}
log.Println("user was created: ", user)
}
内存模式
// 使用内存数据库(测试场景)
func NewTestClient() *ent.Client {
client, err := ent.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
if err != nil {
log.Fatalf("failed opening in-memory db: %v", err)
}
// 运行迁移
if err := client.Schema.Create(context.Background()); err != nil {
log.Fatalf("failed creating schema resources: %v", err)
}
return client
}
SQLite优化配置
package database
import (
"database/sql"
"time"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"myapp/ent"
)
func NewSQLiteClient(dbPath string) (*ent.Client, error) {
// 构建DSN
dsn := "file:" + dbPath +
"?cache=shared" +
"&_pragma=foreign_keys(1)" +
"&_pragma=journal_mode(WAL)" +
"&_pragma=synchronous(NORMAL)" +
"&_pragma=busy_timeout(10000)" +
"&_pragma=cache_size(-64000)" +
"&_pragma=temp_store(MEMORY)" +
"&_pragma=mmap_size(268435456)" // 256MB mmap
// 打开数据库
db, err := sql.Open("sqlite3", dsn)
if err != nil {
return nil, err
}
// 配置连接池(SQLite建议单连接)
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
db.SetConnMaxLifetime(time.Hour)
// 创建Ent客户端
drv := entsql.OpenDB(dialect.SQLite, db)
client := ent.NewClient(ent.Driver(drv))
return client, nil
}
使用lib-x/entsqlite
// 另一个流行的非CGO SQLite驱动
import _ "github.com/lib-x/entsqlite"
func main() {
dsn := "file:./data.db?cache=shared" +
"&_pragma=foreign_keys(1)" +
"&_pragma=journal_mode(WAL)" +
"&_pragma=synchronous(NORMAL)" +
"&_pragma=busy_timeout(10000)"
client, err := ent.Open("sqlite3", dsn)
if err != nil {
log.Fatalf("failed opening connection: %v", err)
}
defer client.Close()
// 使用client...
}
SQLite特性说明
WAL模式(Write-Ahead Logging):
- 提升并发读写性能
- 允许读操作和写操作并行
- 生产环境推荐使用
同步模式:
FULL: 最安全,性能最慢NORMAL: 平衡(推荐)OFF: 最快,但可能丢数据
忙等待超时:
- 当数据库被锁定时等待的时间
- 10000毫秒(10秒)对大多数应用合适
缓存大小:
- 负数表示KB,例如 -64000 = 64MB
- 增加缓存可以提升性能
三、Schema定义详解
3.1 字段类型(Fields)
Ent支持多种字段类型,每种类型都有丰富的配置选项:
import (
"time"
"regexp"
"entgo.io/ent"
"entgo.io/ent/schema/field"
)
type User struct {
ent.Schema
}
func (User) Fields() []ent.Field {
return []ent.Field{
// 基础类型
field.Int("id"),
field.Int64("big_id"),
field.Float("score"),
field.Bool("active"),
field.String("name"),
field.Bytes("avatar"),
field.Time("created_at"),
field.UUID("uuid", uuid.UUID{}),
field.JSON("metadata", map[string]interface{}{}),
// 枚举类型
field.Enum("status").
Values("pending", "active", "inactive", "deleted"),
// 带验证的字段
field.String("email").
Match(regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)).
Unique(),
// 可选字段(允许NULL)
field.String("nickname").
Optional().
Nillable(),
// 带默认值
field.Time("created_at").
Default(time.Now).
Immutable(), // 创建后不可修改
// 自定义校验
field.Int("age").
Positive().
Max(150),
// 敏感字段(不会在日志中显示)
field.String("password").
Sensitive(),
// 带注释
field.String("description").
Comment("用户描述信息"),
// 存储键名映射
field.String("phone_number").
StorageKey("phone"),
// 结构体嵌入
field.JSON("address", &Address{}).
Optional(),
}
}
type Address struct {
Street string `json:"street"`
City string `json:"city"`
Country string `json:"country"`
}
3.2 关系定义(Edges)
Edges代表Ent中的关系(一对一、一对多、多对多)。
package schema
import (
"entgo.io/ent"
"entgo.io/ent/schema/edge"
)
// User schema
type User struct {
ent.Schema
}
func (User) Edges() []ent.Edge {
return []ent.Edge{
// 一对多:一个用户有多篇文章
edge.To("posts", Post.Type),
// 一对一:一个用户有一个档案
edge.To("profile", Profile.Type).
Unique(),
// 多对多:用户可以属于多个群组
edge.To("groups", Group.Type),
// 自引用:用户的朋友关系
edge.To("friends", User.Type),
// 带边缘属性的多对多关系
edge.To("following", User.Type).
Through("user_following", UserFollowing.Type),
}
}
// Post schema
type Post struct {
ent.Schema
}
func (Post) Edges() []ent.Edge {
return []ent.Edge{
// 反向边:文章属于一个用户
edge.From("author", User.Type).
Ref("posts").
Unique().
Required(),
// 多对多:文章可以有多个标签
edge.To("tags", Tag.Type),
}
}
// Profile schema (一对一示例)
type Profile struct {
ent.Schema
}
func (Profile) Edges() []ent.Edge {
return []ent.Edge{
edge.From("user", User.Type).
Ref("profile").
Unique().
Required(),
}
}
// Group schema
type Group struct {
ent.Schema
}
func (Group) Edges() []ent.Edge {
return []ent.Edge{
edge.From("members", User.Type).
Ref("groups"),
}
}
// UserFollowing - 用于存储关注关系的边缘属性
type UserFollowing struct {
ent.Schema
}
func (UserFollowing) Fields() []ent.Field {
return []ent.Field{
field.Time("followed_at").
Default(time.Now),
}
}
func (UserFollowing) Edges() []ent.Edge {
return []ent.Edge{
edge.To("follower", User.Type).
Unique().
Required(),
edge.To("following", User.Type).
Unique().
Required(),
}
}
3.3 索引定义(Indexes)
package schema
import (
"entgo.io/ent"
"entgo.io/ent/schema/index"
)
type User struct {
ent.Schema
}
func (User) Indexes() []ent.Index {
return []ent.Index{
// 单字段索引
index.Fields("email"),
// 复合索引
index.Fields("name", "age"),
// 唯一索引
index.Fields("username").
Unique(),
// 复合唯一索引
index.Fields("tenant_id", "email").
Unique(),
// 边和字段的联合索引
index.Fields("name").
Edges("parent"),
}
}
3.4 Mixin(混入)
Mixin是一个接口,允许创建可重用的schema片段,可以注入到另一个schema中。Mixin可以是一组字段、边或钩子。
package schema
import (
"time"
"entgo.io/ent"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/mixin"
)
// TimeMixin 提供创建时间和更新时间字段
type TimeMixin struct {
mixin.Schema
}
func (TimeMixin) Fields() []ent.Field {
return []ent.Field{
field.Time("created_at").
Immutable().
Default(time.Now),
field.Time("updated_at").
Default(time.Now).
UpdateDefault(time.Now),
}
}
// SoftDeleteMixin 提供软删除功能
type SoftDeleteMixin struct {
mixin.Schema
}
func (SoftDeleteMixin) Fields() []ent.Field {
return []ent.Field{
field.Time("deleted_at").
Optional().
Nillable(),
}
}
// TenantMixin 提供多租户支持
type TenantMixin struct {
mixin.Schema
}
func (TenantMixin) Fields() []ent.Field {
return []ent.Field{
field.Int("tenant_id").
Immutable(),
}
}
// User使用Mixin
type User struct {
ent.Schema
}
func (User) Mixin() []ent.Mixin {
return []ent.Mixin{
TimeMixin{},
SoftDeleteMixin{},
TenantMixin{},
}
}
func (User) Fields() []ent.Field {
return []ent.Field{
field.String("name"),
field.String("email").Unique(),
}
}
四、CRUD操作详解
4.1 创建操作(Create)
func CreateExamples(ctx context.Context, client *ent.Client) {
// 创建单个实体
user, err := client.User.
Create().
SetName("张三").
SetEmail("zhangsan@example.com").
SetAge(25).
Save(ctx)
// 创建并返回ID
id, err := client.User.
Create().
SetName("李四").
SetEmail("lisi@example.com").
SaveX(ctx).ID // SaveX会panic如果出错
// 批量创建
users, err := client.User.
CreateBulk(
client.User.Create().SetName("用户1").SetEmail("user1@example.com"),
client.User.Create().SetName("用户2").SetEmail("user2@example.com"),
client.User.Create().SetName("用户3").SetEmail("user3@example.com"),
).
Save(ctx)
// 创建并关联关系
post, err := client.Post.
Create().
SetTitle("我的第一篇文章").
SetContent("内容...").
SetAuthor(user). // 关联用户
AddTags(tag1, tag2). // 添加多个标签
Save(ctx)
// Upsert(存在则更新,不存在则创建)
err = client.User.
Create().
SetEmail("zhangsan@example.com").
SetName("张三").
OnConflict(
sql.ConflictColumns("email"),
).
UpdateNewValues().
Exec(ctx)
// Upsert批量
err = client.User.
CreateBulk(builders...).
OnConflict(
sql.ConflictColumns("email"),
).
UpdateNewValues().
Exec(ctx)
}
4.2 查询操作(Query)
func QueryExamples(ctx context.Context, client *ent.Client) {
// 查询单个实体 - 通过ID
user, err := client.User.Get(ctx, 1)
// 查询单个实体 - 通过条件
user, err = client.User.
Query().
Where(user.Email("zhangsan@example.com")).
Only(ctx) // 只返回一个,多于一个会报错
// First - 返回第一个
user, err = client.User.
Query().
Where(user.AgeGT(18)).
First(ctx)
// 查询多个实体
users, err := client.User.
Query().
Where(
user.Or(
user.AgeGT(18),
user.NameContains("张"),
),
).
All(ctx)
// 条件组合 - 复杂查询
users, err = client.User.
Query().
Where(
user.And(
user.AgeGTE(18),
user.AgeLTE(60),
user.Or(
user.StatusEQ("active"),
user.StatusEQ("pending"),
),
user.Not(
user.NameHasPrefix("test_"),
),
),
).
All(ctx)
// 排序
users, err = client.User.
Query().
Order(ent.Desc(user.FieldCreatedAt)).
Order(ent.Asc(user.FieldName)).
All(ctx)
// 分页
users, err = client.User.
Query().
Offset(0).
Limit(10).
All(ctx)
// 选择特定字段
names, err := client.User.
Query().
Select(user.FieldName, user.FieldEmail).
Strings(ctx)
// 去重
names, err = client.User.
Query().
Unique(true).
Select(user.FieldName).
Strings(ctx)
// 计数
count, err := client.User.
Query().
Where(user.StatusEQ("active")).
Count(ctx)
// 判断是否存在
exists, err := client.User.
Query().
Where(user.Email("test@example.com")).
Exist(ctx)
// 聚合查询
var result []struct {
Status string
Count int
SumAge int
AvgAge float64
MaxAge int
MinAge int
}
err = client.User.
Query().
GroupBy(user.FieldStatus).
Aggregate(
ent.Count(),
ent.Sum(user.FieldAge),
ent.Mean(user.FieldAge),
ent.Max(user.FieldAge),
ent.Min(user.FieldAge),
).
Scan(ctx, &result)
// Having子句
err = client.User.
Query().
GroupBy(user.FieldStatus).
Aggregate(ent.Count()).
Having(
sql.GT(sql.Count("*"), 10),
).
Scan(ctx, &result)
}
4.3 关系查询(Edge Query)
func EdgeQueryExamples(ctx context.Context, client *ent.Client) {
// 查询用户的所有文章
user, _ := client.User.Get(ctx, 1)
posts, err := user.QueryPosts().All(ctx)
// 带条件的关系查询
posts, err = user.
QueryPosts().
Where(post.StatusEQ("published")).
Order(ent.Desc(post.FieldCreatedAt)).
Limit(10).
All(ctx)
// 预加载(Eager Loading)- 避免N+1问题
users, err := client.User.
Query().
WithPosts(). // 预加载文章
WithProfile(). // 预加载档案
WithGroups(func(q *ent.GroupQuery) {
// 可以对预加载的关系添加条件
q.Where(group.ActiveEQ(true))
q.Order(ent.Asc(group.FieldName))
}).
All(ctx)
// 访问预加载的数据
for _, u := range users {
fmt.Println("User:", u.Name)
for _, p := range u.Edges.Posts {
fmt.Println(" Post:", p.Title)
}
if u.Edges.Profile != nil {
fmt.Println(" Bio:", u.Edges.Profile.Bio)
}
}
// 深层嵌套预加载
users, err = client.User.
Query().
WithPosts(func(q *ent.PostQuery) {
q.WithTags()
q.WithComments(func(cq *ent.CommentQuery) {
cq.WithAuthor()
})
}).
All(ctx)
// 通过关系查询(反向查询)
// 查询发表了某篇文章的用户
author, err := client.Post.
Query().
Where(post.IDEQ(1)).
QueryAuthor().
Only(ctx)
// 链式关系查询
// 查询用户朋友的所有文章
posts, err = client.User.
Query().
Where(user.IDEQ(1)).
QueryFriends().
QueryPosts().
Where(post.StatusEQ("published")).
All(ctx)
// Named Edge(命名边加载)
users, err = client.User.
Query().
WithNamedPosts("recent_posts", func(q *ent.PostQuery) {
q.Order(ent.Desc(post.FieldCreatedAt)).Limit(5)
}).
WithNamedPosts("popular_posts", func(q *ent.PostQuery) {
q.Order(ent.Desc(post.FieldViewCount)).Limit(5)
}).
All(ctx)
// 访问命名边
for _, u := range users {
recentPosts, _ := u.NamedPosts("recent_posts")
popularPosts, _ := u.NamedPosts("popular_posts")
}
}
4.4 更新操作(Update)
func UpdateExamples(ctx context.Context, client *ent.Client) {
// 更新单个实体(通过ID)
user, err := client.User.
UpdateOneID(1).
SetName("新名字").
SetAge(26).
Save(ctx)
// 更新单个实体(通过实体对象)
user, _ = client.User.Get(ctx, 1)
user, err = user.
Update().
SetName("新名字").
AddAge(1). // 年龄+1
Save(ctx)
// 条件更新(批量)
affected, err := client.User.
Update().
Where(user.StatusEQ("inactive")).
SetStatus("deleted").
Save(ctx)
fmt.Printf("更新了 %d 条记录\n", affected)
// 清除可选字段
user, err = client.User.
UpdateOneID(1).
ClearNickname(). // 设置为NULL
ClearDeletedAt().
Save(ctx)
// 更新关系
user, err = client.User.
UpdateOneID(1).
AddPostIDs(1, 2, 3). // 添加文章关联
RemovePostIDs(4, 5). // 移除文章关联
ClearPosts(). // 清除所有文章关联
SetProfileID(1). // 设置档案
Save(ctx)
// 添加/移除多对多关系
user, err = client.User.
UpdateOneID(1).
AddGroups(group1, group2).
RemoveGroups(group3).
Save(ctx)
// 数值字段操作
post, err := client.Post.
UpdateOneID(1).
AddViewCount(1). // 浏览量+1
AddLikeCount(1). // 点赞+1
Save(ctx)
// 使用Modifier进行更复杂的更新
err = client.User.
Update().
Where(user.IDEQ(1)).
Modify(func(u *sql.UpdateBuilder) {
u.Set("login_count", sql.Expr("login_count + 1"))
u.Set("last_login_at", sql.Expr("NOW()"))
}).
Exec(ctx)
}
4.5 删除操作(Delete)
func DeleteExamples(ctx context.Context, client *ent.Client) {
// 删除单个实体(通过ID)
err := client.User.DeleteOneID(1).Exec(ctx)
// 删除单个实体(通过实体对象)
user, _ := client.User.Get(ctx, 1)
err = client.User.DeleteOne(user).Exec(ctx)
// 条件删除(批量)
affected, err := client.User.
Delete().
Where(
user.And(
user.StatusEQ("deleted"),
user.DeletedAtLT(time.Now().AddDate(0, -6, 0)), // 6个月前删除的
),
).
Exec(ctx)
fmt.Printf("删除了 %d 条记录\n", affected)
// 清空表(谨慎使用)
affected, err = client.User.Delete().Exec(ctx)
}
五、事务处理
5.1 基本事务
func TransactionExample(ctx context.Context, client *ent.Client) error {
// 开始事务
tx, err := client.Tx(ctx)
if err != nil {
return err
}
// 创建用户
user, err := tx.User.
Create().
SetName("张三").
SetEmail("zhangsan@example.com").
Save(ctx)
if err != nil {
return rollback(tx, err)
}
// 创建档案
_, err = tx.Profile.
Create().
SetBio("这是个人简介").
SetUser(user).
Save(ctx)
if err != nil {
return rollback(tx, err)
}
// 提交事务
return tx.Commit()
}
// 回滚辅助函数
func rollback(tx *ent.Tx, err error) error {
if rerr := tx.Rollback(); rerr != nil {
err = fmt.Errorf("%w: %v", err, rerr)
}
return err
}
5.2 使用WithTx辅助函数
// WithTx 封装事务处理逻辑
func WithTx(ctx context.Context, client *ent.Client, fn func(tx *ent.Tx) error) error {
tx, err := client.Tx(ctx)
if err != nil {
return err
}
defer func() {
if v := recover(); v != nil {
tx.Rollback()
panic(v)
}
}()
if err := fn(tx); err != nil {
if rerr := tx.Rollback(); rerr != nil {
err = fmt.Errorf("%w: rolling back transaction: %v", err, rerr)
}
return err
}
return tx.Commit()
}
// 使用示例
func CreateUserWithProfile(ctx context.Context, client *ent.Client) error {
return WithTx(ctx, client, func(tx *ent.Tx) error {
user, err := tx.User.
Create().
SetName("张三").
SetEmail("zhangsan@example.com").
Save(ctx)
if err != nil {
return err
}
_, err = tx.Profile.
Create().
SetBio("个人简介").
SetUser(user).
Save(ctx)
return err
})
}
5.3 嵌套事务和保存点
func NestedTransactionExample(ctx context.Context, client *ent.Client) error {
tx, err := client.Tx(ctx)
if err != nil {
return err
}
// 创建第一个用户
user1, err := tx.User.Create().SetName("用户1").SetEmail("user1@example.com").Save(ctx)
if err != nil {
return rollback(tx, err)
}
// 创建保存点
savepoint := tx.Savepoint(ctx, "sp1")
// 尝试创建第二个用户
_, err = tx.User.Create().SetName("用户2").SetEmail("user1@example.com").Save(ctx) // 邮箱重复
if err != nil {
// 回滚到保存点,而不是整个事务
if rerr := savepoint.Rollback(); rerr != nil {
return rollback(tx, fmt.Errorf("%w: %v", err, rerr))
}
// 继续执行其他操作...
}
return tx.Commit()
}
六、Hooks(钩子)
Hooks允许你在实体创建、更新或删除之前或之后执行自定义逻辑。
6.1 Schema级别的Hooks
package schema
import (
"context"
"fmt"
"time"
"entgo.io/ent"
"entgo.io/ent/schema/field"
gen "myapp/ent"
"myapp/ent/hook"
)
type User struct {
ent.Schema
}
func (User) Hooks() []ent.Hook {
return []ent.Hook{
// 创建前的钩子
hook.On(
func(next ent.Mutator) ent.Mutator {
return hook.UserFunc(func(ctx context.Context, m *gen.UserMutation) (ent.Value, error) {
// 设置默认值
if _, exists := m.Status(); !exists {
m.SetStatus("pending")
}
// 记录日志
name, _ := m.Name()
fmt.Printf("Creating user: %s\n", name)
return next.Mutate(ctx, m)
})
},
ent.OpCreate,
),
// 更新前的钩子
hook.On(
func(next ent.Mutator) ent.Mutator {
return hook.UserFunc(func(ctx context.Context, m *gen.UserMutation) (ent.Value, error) {
// 自动更新updated_at字段
m.SetUpdatedAt(time.Now())
return next.Mutate(ctx, m)
})
},
ent.OpUpdate|ent.OpUpdateOne,
),
// 删除前的钩子
hook.On(
func(next ent.Mutator) ent.Mutator {
return hook.UserFunc(func(ctx context.Context, m *gen.UserMutation) (ent.Value, error) {
// 获取要删除的用户ID
id, exists := m.ID()
if exists {
fmt.Printf("Deleting user: %d\n", id)
}
return next.Mutate(ctx, m)
})
},
ent.OpDelete|ent
),
}
}
6.2 运行时Hooks(全局Hooks)
package main
import (
"context"
"fmt"
"log"
"time"
"myapp/ent"
"myapp/ent/hook"
)
func main() {
client, err := ent.Open("postgres", "...")
if err != nil {
log.Fatal(err)
}
// 注册全局钩子
client.Use(
// 日志钩子 - 记录所有操作
func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
start := time.Now()
// 执行操作
v, err := next.Mutate(ctx, m)
// 记录日志
log.Printf(
"Op=%s, Type=%s, Duration=%s, Error=%v",
m.Op().String(),
m.Type(),
time.Since(start),
err,
)
return v, err
})
},
// 审计钩子 - 记录操作者
func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
// 从context获取当前用户
userID, ok := ctx.Value("user_id").(int)
if ok {
// 根据Mutation类型设置审计字段
if m.Op().Is(ent.OpCreate) {
if setter, ok := m.(interface{ SetCreatedBy(int) }); ok {
setter.SetCreatedBy(userID)
}
}
if m.Op().Is(ent.OpUpdate | ent.OpUpdateOne) {
if setter, ok := m.(interface{ SetUpdatedBy(int) }); ok {
setter.SetUpdatedBy(userID)
}
}
}
return next.Mutate(ctx, m)
})
},
)
// 针对特定类型的钩子
client.User.Use(
func(next ent.Mutator) ent.Mutator {
return hook.UserFunc(func(ctx context.Context, m *ent.UserMutation) (ent.Value, error) {
// 密码加密
if password, exists := m.Password(); exists {
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
m.SetPassword(string(hashedPassword))
}
return next.Mutate(ctx, m)
})
},
)
}
6.3 条件Hooks
// 只在特定条件下执行的钩子
func ConditionalHook() ent.Hook {
hk := func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
// 执行验证逻辑
if um, ok := m.(*ent.UserMutation); ok {
email, exists := um.Email()
if exists && !isValidEmail(email) {
return nil, fmt.Errorf("invalid email format: %s", email)
}
}
return next.Mutate(ctx, m)
})
}
// 只在创建和更新时执行
return hook.On(hk, ent.OpCreate|ent.OpUpdate|ent.OpUpdateOne)
}
// 排除某些操作
func ExcludeHook() ent.Hook {
hk := func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
// 你的逻辑
return next.Mutate(ctx, m)
})
}
// 排除删除操作
return hook.Unless(hk, ent.OpDelete|ent.OpDeleteOne)
}
// 只对特定字段变更时执行
func FieldChangeHook() ent.Hook {
return hook.If(
func(next ent.Mutator) ent.Mutator {
return hook.UserFunc(func(ctx context.Context, m *ent.UserMutation) (ent.Value, error) {
oldEmail, _ := m.OldEmail(ctx)
newEmail, _ := m.Email()
// 发送邮箱变更通知
sendEmailChangeNotification(oldEmail, newEmail)
return next.Mutate(ctx, m)
})
},
hook.HasFields("email"), // 只有email字段变更时才执行
)
}
七、Privacy(隐私/访问控制)
Ent的Privacy层允许定义细粒度的访问控制规则,主要用于数据访问控制和权限管理,它在数据库操作级别提供细粒度的访问控制。以下是主要使用场景:
核心使用场景
- 多租户隔离 确保用户只能访问自己的数据。比如在SaaS应用中,用户A不能查询或修改用户B的数据。
// 示例:用户只能看到自己的文章
func FilterTenantRule() privacy.QueryRule {
return privacy.FilterFunc(func(ctx context.Context, f privacy.Filter) error {
userID := getUserIDFromContext(ctx)
f.Where(article.HasOwnerWith(user.ID(userID)))
return nil
})
}
- 基于角色的访问控制(RBAC) 根据用户角色限制操作权限。比如只有管理员可以删除某些资源,普通用户只能读取。
- 行级权限控制 根据数据的状态或属性控制访问。例如,已发布的文章所有人可见,草稿只有作者可见。
- 字段级权限 控制敏感字段的访问,如用户的手机号、邮箱等只有本人或管理员可见。
- 审计和合规 强制执行数据访问策略,确保符合GDPR等数据保护法规。
Privacy的两种规则类型
- QueryRule: 控制读取操作(Query, Get)
- MutationRule: 控制写入操作(Create, Update, Delete)
Privacy层的优势是将权限逻辑与业务代码分离,在ORM层统一处理,避免在每个查询中手动添加权限检查,减少安全漏洞风险。
7.1 配置Privacy
首先在生成配置中启用Privacy:
// ent/generate.go
package ent
//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature privacy ./schema
7.2 定义Privacy规则
package schema
import (
"context"
"entgo.io/ent"
"entgo.io/ent/privacy"
"entgo.io/ent/schema/field"
"myapp/ent/predicate"
)
type User struct {
ent.Schema
}
func (User) Fields() []ent.Field {
return []ent.Field{
field.String("name"),
field.String("email"),
field.Int("tenant_id"),
field.Enum("role").Values("admin", "user", "guest"),
}
}
func (User) Policy() ent.Policy {
return privacy.Policy{
// 查询规则
Query: privacy.QueryPolicy{
// 管理员可以查看所有用户
privacy.QueryRuleFunc(func(ctx context.Context, q ent.Query) error {
if isAdmin(ctx) {
return privacy.Allow
}
return privacy.Skip // 继续下一个规则
}),
// 普通用户只能查看同一租户的用户
privacy.QueryRuleFunc(func(ctx context.Context, q ent.Query) error {
tenantID := getTenantID(ctx)
if tenantID == 0 {
return privacy.Deny
}
// 添加租户过滤条件
userQuery := q.(*ent.UserQuery)
userQuery.Where(user.TenantIDEQ(tenantID))
return privacy.Allow
}),
},
// 变更规则(创建、更新、删除)
Mutation: privacy.MutationPolicy{
// 只有管理员可以创建用户
privacy.OnMutationOperation(
privacy.MutationRuleFunc(func(ctx context.Context, m ent.Mutation) error {
if isAdmin(ctx) {
return privacy.Allow
}
return privacy.Deny
}),
ent.OpCreate,
),
// 用户只能更新自己的信息
privacy.OnMutationOperation(
privacy.MutationRuleFunc(func(ctx context.Context, m ent.Mutation) error {
if isAdmin(ctx) {
return privacy.Allow
}
um := m.(*ent.UserMutation)
id, exists := um.ID()
if !exists {
return privacy.Deny
}
currentUserID := getCurrentUserID(ctx)
if id == currentUserID {
return privacy.Allow
}
return privacy.Deny
}),
ent.OpUpdate|ent.OpUpdateOne,
),
// 只有管理员可以删除用户
privacy.OnMutationOperation(
privacy.DenyMutationOperationRule(ent.OpDelete|ent.OpDeleteOne),
),
},
}
}
// 辅助函数
func isAdmin(ctx context.Context) bool {
role, ok := ctx.Value("role").(string)
return ok && role == "admin"
}
func getTenantID(ctx context.Context) int {
tenantID, _ := ctx.Value("tenant_id").(int)
return tenantID
}
func getCurrentUserID(ctx context.Context) int {
userID, _ := ctx.Value("user_id").(int)
return userID
}
7.3 多租户隔离规则
package rule
import (
"context"
"entgo.io/ent/privacy"
"myapp/ent"
"myapp/ent/predicate"
"myapp/ent/user"
)
// FilterTenantRule 自动添加租户过滤
type FilterTenantRule struct {
privacy.QueryMutationRule
}
func (f FilterTenantRule) EvalQuery(ctx context.Context, q ent.Query) error {
tenantID := getTenantID(ctx)
if tenantID == 0 {
return privacy.Denyf("missing tenant information")
}
// 根据查询类型添加过滤条件
switch q := q.(type) {
case *ent.UserQuery:
q.Where(user.TenantIDEQ(tenantID))
case *ent.PostQuery:
q.Where(post.TenantIDEQ(tenantID))
// ... 其他实体
}
return privacy.Skip
}
func (f FilterTenantRule) EvalMutation(ctx context.Context, m ent.Mutation) error {
tenantID := getTenantID(ctx)
if tenantID == 0 {
return privacy.Denyf("missing tenant information")
}
// 创建时自动设置租户ID
if m.Op().Is(ent.OpCreate) {
if setter, ok := m.(interface{ SetTenantID(int) }); ok {
setter.SetTenantID(tenantID)
}
}
return privacy.Skip
}
7.4 跳过Privacy检查
// 使用privacy.DecisionContext跳过检查
func AdminOperation(ctx context.Context, client *ent.Client) error {
// 允许所有操作
allowCtx := privacy.DecisionContext(ctx, privacy.Allow)
users, err := client.User.Query().All(allowCtx)
// ...
return err
}
八、拦截器(Interceptors)
拦截器是Ent v0.11引入的功能,用于拦截查询操作。
package main
import (
"context"
"log"
"entgo.io/ent"
entgo "myapp/ent"
"myapp/ent/intercept"
)
func main() {
client, _ := entgo.Open("postgres", "...")
// 查询拦截器
client.Intercept(
// 日志拦截器
intercept.Func(func(ctx context.Context, q intercept.Query) error {
log.Printf("Query: Type=%s", q.Type())
return nil
}),
// 软删除过滤器 - 自动过滤已删除的记录
intercept.TraverseFunc(func(ctx context.Context, q intercept.Query) error {
// 检查是否需要跳过软删除过滤
if skipSoftDelete(ctx) {
return nil
}
// 添加deleted_at IS NULL条件
switch q := q.(type) {
case *entgo.UserQuery:
q.Where(user.DeletedAtIsNil())
case *entgo.PostQuery:
q.Where(post.DeletedAtIsNil())
}
return nil
}),
)
// 特定类型的拦截器
client.User.Intercept(
intercept.TraverseUser(func(ctx context.Context, q *entgo.UserQuery) error {
// 用户特定的拦截逻辑
log.Println("Querying users...")
return nil
}),
)
}
func skipSoftDelete(ctx context.Context) bool {
skip, _ := ctx.Value("skip_soft_delete").(bool)
return skip
}
九、数据库迁移
9.1 自动迁移
import (
"context"
"entgo.io/ent/dialect/migrate"
"myapp/ent"
)
func AutoMigration(ctx context.Context, client *ent.Client) error {
// 基本迁移
err := client.Schema.Create(ctx)
// 带选项的迁移
err = client.Schema.Create(ctx,
// 创建外键约束
migrate.WithForeignKeys(true),
// 创建索引
migrate.WithDropIndex(true),
migrate.WithDropColumn(true),
// 全局唯一ID(用于联邦查询)
migrate.WithGlobalUniqueID(true),
)
return err
}
9.2 版本化迁移(Atlas集成)
Ent与Atlas集成提供版本化迁移能力:
# 安装Atlas CLI
curl -sSf https://atlasgo.sh | sh
# 或使用Go安装
go install ariga.io/atlas/cmd/atlas@latest
配置版本化迁移:
// ent/migrate/main.go
//go:build ignore
package main
import (
"context"
"log"
"os"
"myapp/ent/migrate"
atlas "ariga.io/atlas/sql/migrate"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql/schema"
_ "github.com/lib/pq"
)
func main() {
ctx := context.Background()
// 创建迁移目录
dir, err := atlas.NewLocalDir("ent/migrate/migrations")
if err != nil {
log.Fatalf("failed creating atlas migration directory: %v", err)
}
// 迁移选项
opts := []schema.MigrateOption{
schema.WithDir(dir),
schema.WithMigrationMode(schema.ModeReplay),
schema.WithDialect(dialect.Postgres),
schema.WithFormatter(atlas.DefaultFormatter),
}
// 生成迁移文件
err = migrate.NamedDiff(ctx, os.Getenv("DATABASE_URL"), os.Args[1], opts...)
if err != nil {
log.Fatalf("failed generating migration file: %v", err)
}
log.Println("Migration file generated successfully")
}
使用Atlas命令:
# 生成迁移文件
go run -mod=mod ent/migrate/main.go add_user_table
# 应用迁移
atlas migrate apply \
--dir "file://ent/migrate/migrations" \
--url "postgres://localhost:5432/myapp?sslmode=disable"
# 查看迁移状态
atlas migrate status \
--dir "file://ent/migrate/migrations" \
--url "postgres://localhost:5432/myapp?sslmode=disable"
# 回滚迁移
atlas migrate down \
--dir "file://ent/migrate/migrations" \
--url "postgres://localhost:5432/myapp?sslmode=disable"
9.3 迁移钩子
func MigrationWithHooks(ctx context.Context, client *ent.Client) error {
err := client.Schema.Create(
ctx,
migrate.WithHooks(func(next schema.Creator) schema.Creator {
return schema.CreateFunc(func(ctx context.Context, tables ...*schema.Table) error {
// 迁移前的操作
log.Println("Starting migration...")
// 执行迁移
if err := next.Create(ctx, tables...); err != nil {
return err
}
// 迁移后的操作
log.Println("Migration completed!")
return nil
})
}),
)
return err
}
9.4 检查Schema差异
func CheckSchemaDiff(ctx context.Context, client *ent.Client) error {
// 获取当前Schema和数据库Schema的差异
changes, err := client.Schema.Diff(ctx)
if err != nil {
return err
}
for _, change := range changes {
log.Printf("Change: %s\n", change)
}
return nil
}
十、高级功能
10.1 自定义ID生成
package schema
import (
"entgo.io/ent"
"entgo.io/ent/schema/field"
"github.com/google/uuid"
"github.com/sony/sonyflake"
)
var sf *sonyflake.Sonyflake
func init() {
sf = sonyflake.NewSonyflake(sonyflake.Settings{})
}
// UUID作为ID
type User struct {
ent.Schema
}
func (User) Fields() []ent.Field {
return []ent.Field{
field.UUID("id", uuid.UUID{}).
Default(uuid.New).
Immutable(),
field.String("name"),
}
}
// Snowflake ID
type Post struct {
ent.Schema
}
func (Post) Fields() []ent.Field {
return []ent.Field{
field.Int64("id").
DefaultFunc(func() int64 {
id, _ := sf.NextID()
return int64(id)
}).
Immutable(),
field.String("title"),
}
}
// 自定义字符串ID
type Order struct {
ent.Schema
}
func (Order) Fields() []ent.Field {
return []ent.Field{
field.String("id").
DefaultFunc(func() string {
return fmt.Sprintf("ORD-%s", uuid.New().String()[:8])
}).
Immutable(),
}
}
10.2 自定义谓词(Predicates)
package customsql
import (
"entgo.io/ent/dialect/sql"
"myapp/ent/predicate"
"myapp/ent/user"
)
// FullTextSearch 全文搜索谓词
func FullTextSearch(term string) predicate.User {
return predicate.User(func(s *sql.Selector) {
s.Where(sql.ExprP(
"to_tsvector('english', name || ' ' || COALESCE(bio, '')) @@ plainto_tsquery('english', $1)",
term,
))
})
}
// WithinDistance 地理位置范围查询
func WithinDistance(lat, lng, distanceKm float64) predicate.User {
return predicate.User(func(s *sql.Selector) {
s.Where(sql.ExprP(
`ST_DWithin(
location::geography,
ST_SetSRID(ST_MakePoint($1, $2), 4326)::geography,
$3
)`,
lng, lat, distanceKm*1000,
))
})
}
// JSONContains JSON字段包含查询
func JSONContains(field, key, value string) predicate.User {
return predicate.User(func(s *sql.Selector) {
s.Where(sql.ExprP(
fmt.Sprintf("%s->>$1 = $2", field),
key, value,
))
})
}
// 使用示例
func SearchUsers(ctx context.Context, client *ent.Client) ([]*ent.User, error) {
return client.User.
Query().
Where(
user.And(
user.StatusEQ("active"),
FullTextSearch("engineer"),
WithinDistance(37.7749, -122.4194, 50), // 旧金山50公里内
),
).
All(ctx)
}
10.3 原生SQL查询
package main
import (
"context"
"database/sql"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
)
// 使用原生SQL查询
func RawSQLQuery(ctx context.Context, client *ent.Client) error {
// 获取底层数据库连接
var results []struct {
ID int `json:"id"`
Name string `json:"name"`
Count int `json:"count"`
}
err := client.User.
Query().
Modify(func(s *entsql.Selector) {
s.Select(
s.C(user.FieldID),
s.C(user.FieldName),
entsql.As(entsql.Count("*"), "count"),
).
From(s.Table()).
LeftJoin(
entsql.Table(post.Table),
).
On(
s.C(user.FieldID),
entsql.Table(post.Table).C(post.FieldAuthorID),
).
GroupBy(s.C(user.FieldID))
}).
Scan(ctx, &results)
return err
}
// 直接执行原生SQL
func ExecRawSQL(ctx context.Context, client *ent.Client) error {
// 获取驱动
drv := client.Driver()
// 执行查询
rows, err := drv.Query(ctx,
"SELECT id, name FROM users WHERE status = $1 LIMIT $2",
[]interface{}{"active", 10},
)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var id int
var name string
if err := rows.Scan(&id, &name); err != nil {
return err
}
fmt.Printf("User: %d - %s\n", id, name)
}
return rows.Err()
}
// 使用sql.DB直接操作
func DirectDBAccess(ctx context.Context, client *ent.Client) error {
// 获取*sql.DB
db := client.DB()
// 执行原生查询
rows, err := db.QueryContext(ctx,
`SELECT u.id, u.name, COUNT(p.id) as post_count
FROM users u
LEFT JOIN posts p ON u.id = p.author_id
GROUP BY u.id
ORDER BY post_count DESC
LIMIT 10`)
if err != nil {
return err
}
defer rows.Close()
// 处理结果...
return nil
}
10.4 分页器
package pagination
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"myapp/ent"
"myapp/ent/user"
)
// PageInfo 分页信息
type PageInfo struct {
HasNextPage bool `json:"has_next_page"`
HasPreviousPage bool `json:"has_previous_page"`
StartCursor string `json:"start_cursor,omitempty"`
EndCursor string `json:"end_cursor,omitempty"`
TotalCount int `json:"total_count"`
}
// UserEdge 用户边缘
type UserEdge struct {
Node *ent.User `json:"node"`
Cursor string `json:"cursor"`
}
// UserConnection 用户连接(游标分页结果)
type UserConnection struct {
Edges []*UserEdge `json:"edges"`
PageInfo *PageInfo `json:"page_info"`
}
// Cursor 游标结构
type Cursor struct {
ID int `json:"id"`
Timestamp int64 `json:"ts"`
}
// EncodeCursor 编码游标
func EncodeCursor(c Cursor) string {
data, _ := json.Marshal(c)
return base64.StdEncoding.EncodeToString(data)
}
// DecodeCursor 解码游标
func DecodeCursor(s string) (Cursor, error) {
var c Cursor
data, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return c, err
}
err = json.Unmarshal(data, &c)
return c, err
}
// PaginateUsers 游标分页查询用户
func PaginateUsers(
ctx context.Context,
client *ent.Client,
first *int,
after *string,
last *int,
before *string,
) (*UserConnection, error) {
query := client.User.Query().
Where(user.StatusEQ("active")).
Order(ent.Desc(user.FieldCreatedAt))
// 获取总数
totalCount, err := query.Clone().Count(ctx)
if err != nil {
return nil, err
}
// 处理after游标
if after != nil {
cursor, err := DecodeCursor(*after)
if err != nil {
return nil, err
}
query = query.Where(user.IDLT(cursor.ID))
}
// 处理before游标
if before != nil {
cursor, err := DecodeCursor(*before)
if err != nil {
return nil, err
}
query = query.Where(user.IDGT(cursor.ID))
}
// 设置限制(多取一个用于判断是否有下一页)
limit := 10
if first != nil {
limit = *first
}
query = query.Limit(limit + 1)
users, err := query.All(ctx)
if err != nil {
return nil, err
}
// 构建结果
hasNextPage := len(users) > limit
if hasNextPage {
users = users[:limit]
}
edges := make([]*UserEdge, len(users))
for i, u := range users {
edges[i] = &UserEdge{
Node: u,
Cursor: EncodeCursor(Cursor{
ID: u.ID,
Timestamp: u.CreatedAt.Unix(),
}),
}
}
var startCursor, endCursor string
if len(edges) > 0 {
startCursor = edges[0].Cursor
endCursor = edges[len(edges)-1].Cursor
}
return &UserConnection{
Edges: edges,
PageInfo: &PageInfo{
HasNextPage: hasNextPage,
HasPreviousPage: after != nil,
StartCursor: startCursor,
EndCursor: endCursor,
TotalCount: totalCount,
},
}, nil
}
// OffsetPagination 偏移分页
type OffsetPagination struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
Total int `json:"total"`
Pages int `json:"pages"`
}
type PaginatedUsers struct {
Users []*ent.User `json:"users"`
Pagination *OffsetPagination `json:"pagination"`
}
func PaginateUsersOffset(
ctx context.Context,
client *ent.Client,
page, pageSize int,
) (*PaginatedUsers, error) {
if page < 1 {
page = 1
}
if pageSize < 1 || pageSize > 100 {
pageSize = 10
}
query := client.User.Query().Where(user.StatusEQ("active"))
// 获取总数
total, err := query.Clone().Count(ctx)
if err != nil {
return nil, err
}
// 计算分页
offset := (page - 1) * pageSize
pages := (total + pageSize - 1) / pageSize
users, err := query.
Order(ent.Desc(user.FieldCreatedAt)).
Offset(offset).
Limit(pageSize).
All(ctx)
if err != nil {
return nil, err
}
return &PaginatedUsers{
Users: users,
Pagination: &OffsetPagination{
Page: page,
PageSize: pageSize,
Total: total,
Pages: pages,
},
}, nil
}
10.5 软删除实现
package schema
import (
"context"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/mixin"
gen "myapp/ent"
"myapp/ent/hook"
"myapp/ent/intercept"
)
// SoftDeleteMixin 软删除混入
type SoftDeleteMixin struct {
mixin.Schema
}
func (SoftDeleteMixin) Fields() []ent.Field {
return []ent.Field{
field.Time("deleted_at").
Optional().
Nillable(),
}
}
// SoftDeleteKey 用于context的键
type softDeleteKey struct{}
// SkipSoftDelete 返回跳过软删除过滤的context
func SkipSoftDelete(ctx context.Context) context.Context {
return context.WithValue(ctx, softDeleteKey{}, true)
}
// 拦截器:查询时自动过滤已删除的记录
func (SoftDeleteMixin) Interceptors() []ent.Interceptor {
return []ent.Interceptor{
intercept.TraverseFunc(func(ctx context.Context, q intercept.Query) error {
// 检查是否跳过软删除过滤
if skip, _ := ctx.Value(softDeleteKey{}).(bool); skip {
return nil
}
// 使用类型断言添加过滤条件
sf, ok := q.(interface {
WhereP(...func(*sql.Selector))
})
if !ok {
return nil
}
sf.WhereP(func(s *sql.Selector) {
s.Where(sql.IsNull(s.C("deleted_at")))
})
return nil
}),
}
}
// 钩子:将删除操作转换为软删除
func (SoftDeleteMixin) Hooks() []ent.Hook {
return []ent.Hook{
hook.On(
func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
// 检查是否强制硬删除
if skip, _ := ctx.Value(softDeleteKey{}).(bool); skip {
return next.Mutate(ctx, m)
}
// 获取删除的实体类型和ID
mx, ok := m.(interface {
SetDeletedAt(time.Time)
WhereP(...func(*sql.Selector))
SetOp(ent.Op)
Client() *gen.Client
})
if !ok {
return next.Mutate(ctx, m)
}
// 将删除转换为更新
mx.SetDeletedAt(time.Now())
mx.SetOp(ent.OpUpdate)
return next.Mutate(ctx, m)
})
},
ent.OpDelete|ent.OpDeleteOne,
),
}
}
// User实体使用软删除
type User struct {
ent.Schema
}
func (User) Mixin() []ent.Mixin {
return []ent.Mixin{
SoftDeleteMixin{},
}
}
// 使用示例
func SoftDeleteExample(ctx context.Context, client *ent.Client) error {
// 普通删除(软删除)
err := client.User.DeleteOneID(1).Exec(ctx)
// 查询时自动过滤已删除的记录
users, err := client.User.Query().All(ctx)
// 查询包括已删除的记录
allUsers, err := client.User.Query().All(SkipSoftDelete(ctx))
// 强制硬删除
err = client.User.DeleteOneID(1).Exec(SkipSoftDelete(ctx))
// 恢复已删除的记录
err = client.User.
UpdateOneID(1).
ClearDeletedAt().
Exec(SkipSoftDelete(ctx))
return err
}
10.6 乐观锁实现
package schema
import (
"context"
"fmt"
"entgo.io/ent"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/mixin"
gen "myapp/ent"
"myapp/ent/hook"
)
// OptimisticLockMixin 乐观锁混入
type OptimisticLockMixin struct {
mixin.Schema
}
func (OptimisticLockMixin) Fields() []ent.Field {
return []ent.Field{
field.Int64("version").
Default(1).
Comment("乐观锁版本号"),
}
}
func (OptimisticLockMixin) Hooks() []ent.Hook {
return []ent.Hook{
hook.On(
func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
// 获取旧版本号
mx, ok := m.(interface {
OldVersion(ctx context.Context) (int64, error)
SetVersion(int64)
WhereP(...func(*sql.Selector))
})
if !ok {
return next.Mutate(ctx, m)
}
oldVersion, err := mx.OldVersion(ctx)
if err != nil {
return nil, err
}
// 添加版本号条件
mx.WhereP(func(s *sql.Selector) {
s.Where(sql.EQ(s.C("version"), oldVersion))
})
// 设置新版本号
mx.SetVersion(oldVersion + 1)
// 执行更新
v, err := next.Mutate(ctx, m)
if err != nil {
return nil, err
}
// 检查是否更新成功(返回的affected rows)
// 如果版本号不匹配,会返回0行被更新
return v, nil
})
},
ent.OpUpdate|ent.OpUpdateOne,
),
}
}
// 使用乐观锁的更新函数
func UpdateWithOptimisticLock(ctx context.Context, client *ent.Client, userID int, name string) error {
// 获取当前实体
user, err := client.User.Get(ctx, userID)
if err != nil {
return err
}
// 尝试更新
affected, err := client.User.
Update().
Where(
user.IDEQ(userID),
user.VersionEQ(user.Version),
).
SetName(name).
SetVersion(user.Version + 1).
Save(ctx)
if err != nil {
return err
}
if affected == 0 {
return fmt.Errorf("optimistic lock conflict: user %d has been modified", userID)
}
return nil
}
// 带重试的乐观锁更新
func UpdateWithRetry(ctx context.Context, client *ent.Client, userID int, updateFn func(*ent.User) error, maxRetries int) error {
for i := 0; i < maxRetries; i++ {
user, err := client.User.Get(ctx, userID)
if err != nil {
return err
}
// 应用更新
if err := updateFn(user); err != nil {
return err
}
// 尝试保存
affected, err := client.User.
Update().
Where(
user.IDEQ(userID),
user.VersionEQ(user.Version),
).
SetName(user.Name).
SetVersion(user.Version + 1).
Save(ctx)
if err != nil {
return err
}
if affected > 0 {
return nil // 更新成功
}
// 版本冲突,重试
log.Printf("Optimistic lock conflict, retry %d/%d", i+1, maxRetries)
}
return fmt.Errorf("failed to update after %d retries", maxRetries)
}
10.7 Redis缓存集成
📋 使用场景
为什么需要缓存?
- 🚀 提升性能 - 减少数据库查询,降低响应时间
- 💰 降低成本 - 减轻数据库负载,节省资源
- 📈 提高并发 - 处理更多请求
- 🔄 减少延迟 - 热点数据快速访问
适用场景:
- 读多写少的数据 - 用户信息、配置信息
- 计算密集型查询 - 统计数据、聚合结果
- 频繁访问的数据 - 热门商品、文章
- 会话数据 - 登录状态、购物车
- 临时数据 - 验证码、限流计数
实际业务例子:
- 电商系统:商品详情、分类信息
- 社交平台:用户资料、关注列表
- 内容平台:文章详情、评论列表
安装依赖
go get github.com/redis/go-redis/v9
基础缓存实现
package cache
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"myapp/ent"
)
type Cache struct {
redis *redis.Client
client *ent.Client
}
func NewCache(redisClient *redis.Client, entClient *ent.Client) *Cache {
return &Cache{
redis: redisClient,
client: entClient,
}
}
// 生成缓存键
func (c *Cache) userKey(id int) string {
return fmt.Sprintf("user:%d", id)
}
// GetUser 从缓存获取用户,缓存未命中时从数据库加载
func (c *Cache) GetUser(ctx context.Context, id int) (*ent.User, error) {
key := c.userKey(id)
// 1. 尝试从Redis获取
data, err := c.redis.Get(ctx, key).Bytes()
if err == nil {
// 缓存命中,反序列化
var user ent.User
if err := json.Unmarshal(data, &user); err == nil {
return &user, nil
}
}
// 2. 缓存未命中,从数据库查询
user, err := c.client.User.Get(ctx, id)
if err != nil {
return nil, err
}
// 3. 写入缓存
if data, err := json.Marshal(user); err == nil {
c.redis.Set(ctx, key, data, 5*time.Minute)
}
return user, nil
}
// InvalidateUser 删除用户缓存
func (c *Cache) InvalidateUser(ctx context.Context, id int) error {
return c.redis.Del(ctx, c.userKey(id)).Err()
}
// SetUser 设置用户缓存
func (c *Cache) SetUser(ctx context.Context, user *ent.User, ttl time.Duration) error {
key := c.userKey(user.ID)
data, err := json.Marshal(user)
if err != nil {
return err
}
return c.redis.Set(ctx, key, data, ttl).Err()
}
Cache-Aside模式(推荐)
package repository
import (
"context"
"encoding/json"
"time"
"github.com/redis/go-redis/v9"
"myapp/ent"
"myapp/ent/user"
)
type UserRepository struct {
db *ent.Client
cache *redis.Client
}
func NewUserRepository(db *ent.Client, cache *redis.Client) *UserRepository {
return &UserRepository{db: db, cache: cache}
}
// GetByID 通过ID获取用户(带缓存)
func (r *UserRepository) GetByID(ctx context.Context, id int) (*ent.User, error) {
key := fmt.Sprintf("user:%d", id)
// 1. 尝试从缓存获取
val, err := r.cache.Get(ctx, key).Result()
if err == nil {
var u ent.User
if err := json.Unmarshal([]byte(val), &u); err == nil {
return &u, nil
}
}
// 2. 缓存未命中,查询数据库
u, err := r.db.User.Get(ctx, id)
if err != nil {
return nil, err
}
// 3. 写入缓存(异步,失败不影响结果)
go func() {
data, _ := json.Marshal(u)
r.cache.Set(context.Background(), key, data, 10*time.Minute)
}()
return u, nil
}
// Update 更新用户(删除缓存)
func (r *UserRepository) Update(ctx context.Context, id int, updateFn func(*ent.UserUpdateOne) *ent.UserUpdateOne) (*ent.User, error) {
// 1. 更新数据库
u, err := updateFn(r.db.User.UpdateOneID(id)).Save(ctx)
if err != nil {
return nil, err
}
// 2. 删除缓存
key := fmt.Sprintf("user:%d", id)
r.cache.Del(ctx, key)
return u, nil
}
// Delete 删除用户(删除缓存)
func (r *UserRepository) Delete(ctx context.Context, id int) error {
// 1. 删除数据库记录
if err := r.db.User.DeleteOneID(id).Exec(ctx); err != nil {
return err
}
// 2. 删除缓存
key := fmt.Sprintf("user:%d", id)
r.cache.Del(ctx, key)
return nil
}
// GetActiveUsers 获取活跃用户列表(带缓存)
func (r *UserRepository) GetActiveUsers(ctx context.Context, limit int) ([]*ent.User, error) {
key := fmt.Sprintf("users:active:%d", limit)
// 1. 尝试从缓存获取
val, err := r.cache.Get(ctx, key).Result()
if err == nil {
var users []*ent.User
if err := json.Unmarshal([]byte(val), &users); err == nil {
return users, nil
}
}
// 2. 查询数据库
users, err := r.db.User.
Query().
Where(user.StatusEQ("active")).
Limit(limit).
All(ctx)
if err != nil {
return nil, err
}
// 3. 写入缓存
go func() {
data, _ := json.Marshal(users)
r.cache.Set(context.Background(), key, data, 5*time.Minute)
}()
return users, nil
}
使用Ent Hook自动管理缓存
package schema
import (
"context"
"fmt"
"github.com/redis/go-redis/v9"
"entgo.io/ent"
gen "myapp/ent"
"myapp/ent/hook"
)
// 获取Redis客户端(从context或全局变量)
func getRedisClient(ctx context.Context) *redis.Client {
if client, ok := ctx.Value("redis").(*redis.Client); ok {
return client
}
return globalRedisClient // 或使用全局客户端
}
// User schema with cache hooks
type User struct {
ent.Schema
}
func (User) Hooks() []ent.Hook {
return []ent.Hook{
// 创建后写入缓存
hook.On(
func(next ent.Mutator) ent.Mutator {
return hook.UserFunc(func(ctx context.Context, m *gen.UserMutation) (ent.Value, error) {
// 执行创建
v, err := next.Mutate(ctx, m)
if err != nil {
return v, err
}
// 获取创建的用户
user := v.(*ent.User)
// 写入缓存
if rdb := getRedisClient(ctx); rdb != nil {
key := fmt.Sprintf("user:%d", user.ID)
data, _ := json.Marshal(user)
rdb.Set(ctx, key, data, 10*time.Minute)
}
return v, nil
})
},
ent.OpCreate,
),
// 更新后删除缓存
hook.On(
func(next ent.Mutator) ent.Mutator {
return hook.UserFunc(func(ctx context.Context, m *gen.UserMutation) (ent.Value, error) {
// 获取用户ID
id, exists := m.ID()
// 执行更新
v, err := next.Mutate(ctx, m)
if err != nil {
return v, err
}
// 删除缓存
if exists {
if rdb := getRedisClient(ctx); rdb != nil {
key := fmt.Sprintf("user:%d", id)
rdb.Del(ctx, key)
}
}
return v, nil
})
},
ent.OpUpdate|ent.OpUpdateOne,
),
// 删除后清除缓存
hook.On(
func(next ent.Mutator) ent.Mutator {
return hook.UserFunc(func(ctx context.Context, m *gen.UserMutation) (ent.Value, error) {
// 获取用户ID
id, exists := m.ID()
// 执行删除
v, err := next.Mutate(ctx, m)
if err != nil {
return v, err
}
// 删除缓存
if exists {
if rdb := getRedisClient(ctx); rdb != nil {
key := fmt.Sprintf("user:%d", id)
rdb.Del(ctx, key)
}
}
return v, nil
})
},
ent.OpDelete|ent.OpDeleteOne,
),
}
}
缓存预热
package cache
import (
"context"
"encoding/json"
"time"
"github.com/redis/go-redis/v9"
"myapp/ent"
"myapp/ent/user"
)
// WarmUpCache 预热缓存
func WarmUpCache(ctx context.Context, db *ent.Client, cache *redis.Client) error {
// 1. 加载热门用户
users, err := db.User.
Query().
Where(user.StatusEQ("active")).
Order(ent.Desc(user.FieldLoginCount)).
Limit(100).
All(ctx)
if err != nil {
return err
}
// 2. 批量写入缓存
pipe := cache.Pipeline()
for _, u := range users {
key := fmt.Sprintf("user:%d", u.ID)
data, _ := json.Marshal(u)
pipe.Set(ctx, key, data, 10*time.Minute)
}
_, err = pipe.Exec(ctx)
return err
}
缓存失效策略
package cache
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9"
)
type InvalidationStrategy struct {
cache *redis.Client
}
// InvalidateUserRelatedCaches 删除用户相关的所有缓存
func (s *InvalidationStrategy) InvalidateUserRelatedCaches(ctx context.Context, userID int) error {
patterns := []string{
fmt.Sprintf("user:%d", userID),
fmt.Sprintf("user:%d:posts:*", userID),
fmt.Sprintf("user:%d:profile", userID),
"users:active:*", // 列表缓存也要删除
}
for _, pattern := range patterns {
keys, err := s.cache.Keys(ctx, pattern).Result()
if err != nil {
continue
}
if len(keys) > 0 {
s.cache.Del(ctx, keys...)
}
}
return nil
}
// SetWithTags 设置带标签的缓存
func (s *InvalidationStrategy) SetWithTags(ctx context.Context, key string, value interface{}, ttl time.Duration, tags ...string) error {
// 1. 设置主键
data, _ := json.Marshal(value)
if err := s.cache.Set(ctx, key, data, ttl).Err(); err != nil {
return err
}
// 2. 为每个标签建立索引
for _, tag := range tags {
tagKey := fmt.Sprintf("tag:%s", tag)
s.cache.SAdd(ctx, tagKey, key)
s.cache.Expire(ctx, tagKey, ttl)
}
return nil
}
// InvalidateByTag 通过标签删除缓存
func (s *InvalidationStrategy) InvalidateByTag(ctx context.Context, tag string) error {
tagKey := fmt.Sprintf("tag:%s", tag)
// 获取该标签下的所有key
keys, err := s.cache.SMembers(ctx, tagKey).Result()
if err != nil {
return err
}
// 删除所有相关key
if len(keys) > 0 {
keys = append(keys, tagKey)
s.cache.Del(ctx, keys...)
}
return nil
}
防止缓存穿透
package cache
import (
"context"
"encoding/json"
"errors"
"time"
"github.com/redis/go-redis/v9"
"myapp/ent"
)
var ErrNotFound = errors.New("not found")
// GetUserWithNullCache 使用空值缓存防止穿透
func GetUserWithNullCache(ctx context.Context, db *ent.Client, cache *redis.Client, id int) (*ent.User, error) {
key := fmt.Sprintf("user:%d", id)
// 1. 尝试从缓存获取
val, err := cache.Get(ctx, key).Result()
if err == nil {
// 检查是否是空值缓存
if val == "null" {
return nil, ErrNotFound
}
var user ent.User
if err := json.Unmarshal([]byte(val), &user); err == nil {
return &user, nil
}
}
// 2. 查询数据库
user, err := db.User.Get(ctx, id)
if err != nil {
// 如果不存在,缓存一个空值(短时间)
if ent.IsNotFound(err) {
cache.Set(ctx, key, "null", 1*time.Minute)
}
return nil, err
}
// 3. 缓存正常数据
data, _ := json.Marshal(user)
cache.Set(ctx, key, data, 10*time.Minute)
return user, nil
}
防止缓存雪崩
package cache
import (
"math/rand"
"time"
)
// RandomTTL 返回带随机抖动的TTL,防止同时过期
func RandomTTL(base time.Duration, jitter float64) time.Duration {
if jitter <= 0 || jitter >= 1 {
return base
}
// 添加±jitter范围的随机时间
delta := time.Duration(float64(base) * jitter)
offset := time.Duration(rand.Int63n(int64(delta*2))) - delta
return base + offset
}
// 使用示例
func SetUserWithJitter(ctx context.Context, cache *redis.Client, user *ent.User) error {
key := fmt.Sprintf("user:%d", user.ID)
data, _ := json.Marshal(user)
// 基础TTL为10分钟,添加±20%的抖动
ttl := RandomTTL(10*time.Minute, 0.2)
return cache.Set(ctx, key, data, ttl).Err()
}
防止缓存击穿(热点key)
package cache
import (
"context"
"encoding/json"
"sync"
"time"
"github.com/redis/go-redis/v9"
"myapp/ent"
)
type SingleFlight struct {
mu sync.Mutex
calls map[string]*call
}
type call struct {
wg sync.WaitGroup
val *ent.User
err error
}
func NewSingleFlight() *SingleFlight {
return &SingleFlight{
calls: make(map[string]*call),
}
}
// GetUser 使用singleflight防止缓存击穿
func (sf *SingleFlight) GetUser(ctx context.Context, db *ent.Client, cache *redis.Client, id int) (*ent.User, error) {
key := fmt.Sprintf("user:%d", id)
// 1. 尝试从缓存获取
val, err := cache.Get(ctx, key).Result()
if err == nil {
var user ent.User
if err := json.Unmarshal([]byte(val), &user); err == nil {
return &user, nil
}
}
// 2. 缓存未命中,使用singleflight
sf.mu.Lock()
if c, ok := sf.calls[key]; ok {
// 已有其他goroutine在加载,等待结果
sf.mu.Unlock()
c.wg.Wait()
return c.val, c.err
}
// 创建新的call
c := &call{}
c.wg.Add(1)
sf.calls[key] = c
sf.mu.Unlock()
// 3. 从数据库加载
c.val, c.err = db.User.Get(ctx, id)
if c.err == nil {
// 写入缓存
data, _ := json.Marshal(c.val)
cache.Set(ctx, key, data, 10*time.Minute)
}
// 4. 通知等待的goroutine
c.wg.Done()
sf.mu.Lock()
delete(sf.calls, key)
sf.mu.Unlock()
return c.val, c.err
}
完整示例:带缓存的服务层
package service
import (
"context"
"time"
"github.com/redis/go-redis/v9"
"myapp/cache"
"myapp/ent"
"myapp/ent/user"
)
type UserService struct {
db *ent.Client
cache *redis.Client
repository *cache.UserRepository
sf *cache.SingleFlight
}
func NewUserService(db *ent.Client, cache *redis.Client) *UserService {
return &UserService{
db: db,
cache: cache,
repository: cache.NewUserRepository(db, cache),
sf: cache.NewSingleFlight(),
}
}
// GetUser 获取用户(使用所有缓存策略)
func (s *UserService) GetUser(ctx context.Context, id int) (*ent.User, error) {
// 使用singleflight + 空值缓存 + 随机TTL
return s.sf.GetUser(ctx, s.db, s.cache, id)
}
// UpdateUser 更新用户(自动失效缓存)
func (s *UserService) UpdateUser(ctx context.Context, id int, name string) (*ent.User, error) {
return s.repository.Update(ctx, id, func(u *ent.UserUpdateOne) *ent.UserUpdateOne {
return u.SetName(name)
})
}
// GetActiveUsers 获取活跃用户列表
func (s *UserService) GetActiveUsers(ctx context.Context, limit int) ([]*ent.User, error) {
return s.repository.GetActiveUsers(ctx, limit)
}
// 启动时预热缓存
func (s *UserService) WarmUp(ctx context.Context) error {
return cache.WarmUpCache(ctx, s.db, s.cache)
}
缓存最佳实践总结
1. 缓存策略选择:
- Cache-Aside(旁路缓存):最常用,应用负责缓存管理
- Read-Through:读穿透,由缓存层负责加载数据
- Write-Through:写穿透,同步写缓存和数据库
- Write-Behind:异步写,先写缓存,异步写数据库
2. 缓存更新策略:
- 更新时删除缓存(推荐)
- 更新时更新缓存(数据一致性要求高时)
- 定时失效(适合计算密集型数据)
3. 三大问题解决:
- 缓存穿透:布隆过滤器 + 空值缓存
- 缓存击穿:Singleflight + 互斥锁
- 缓存雪崩:随机TTL + 多级缓存
4. 监控指标:
- 缓存命中率(Hit Rate)
- 缓存延迟(Latency)
- 内存使用量
- 过期Key数量
十一、GraphQL集成
Ent原生支持GraphQL,可以自动生成GraphQL schema和resolver。
11.1 配置GraphQL生成
// ent/entc.go
//go:build ignore
package main
import (
"log"
"entgo.io/contrib/entgql"
"entgo.io/ent/entc"
"entgo.io/ent/entc/gen"
)
func main() {
ex, err := entgql.NewExtension(
// 生成GQL schema
entgql.WithSchemaGenerator(),
entgql.WithSchemaPath("graph/ent.graphql"),
// 配置自动生成的类型
entgql.WithConfigPath("gqlgen.yml"),
// 启用Relay规范
entgql.WithRelaySpec(true),
// Node接口
entgql.WithNodeDescriptor(true),
)
if err != nil {
log.Fatalf("creating entgql extension: %v", err)
}
opts := []entc.Option{
entc.Extensions(ex),
}
if err := entc.Generate("./ent/schema", &gen.Config{}, opts...); err != nil {
log.Fatalf("running ent codegen: %v", err)
}
}
11.2 Schema添加GraphQL注解
package schema
import (
"entgo.io/contrib/entgql"
"entgo.io/ent"
"entgo.io/ent/schema"
"entgo.io/ent/schema/edge"
"entgo.io/ent/schema/field"
)
type User struct {
ent.Schema
}
func (User) Fields() []ent.Field {
return []ent.Field{
field.String("name").
Annotations(
entgql.OrderField("NAME"),
),
field.String("email").
Annotations(
entgql.OrderField("EMAIL"),
),
field.Time("created_at").
Annotations(
entgql.OrderField("CREATED_AT"),
),
field.Enum("status").
Values("active", "inactive").
Annotations(
entgql.Type("UserStatus"),
),
}
}
func (User) Edges() []ent.Edge {
return []ent.Edge{
edge.To("posts", Post.Type).
Annotations(
entgql.RelayConnection(),
entgql.OrderField("POSTS_COUNT"),
),
}
}
func (User) Annotations() []schema.Annotation {
return []schema.Annotation{
entgql.RelayConnection(),
entgql.QueryField(),
entgql.Mutations(
entgql.MutationCreate(),
entgql.MutationUpdate(),
),
}
}
11.3 GraphQL Resolver实现
package resolver
import (
"context"
"myapp/ent"
"myapp/graph/generated"
)
type Resolver struct {
Client *ent.Client
}
func NewSchema(client *ent.Client) *generated.Config {
return &generated.Config{
Resolvers: &Resolver{Client: client},
}
}
// Query resolver
type queryResolver struct{ *Resolver }
func (r *Resolver) Query() generated.QueryResolver {
return &queryResolver{r}
}
func (r *queryResolver) Node(ctx context.Context, id int) (ent.Noder, error) {
return r.Client.Noder(ctx, id)
}
func (r *queryResolver) Nodes(ctx context.Context, ids []int) ([]ent.Noder, error) {
return r.Client.Noders(ctx, ids)
}
func (r *queryResolver) Users(
ctx context.Context,
after *ent.Cursor,
first *int,
before *ent.Cursor,
last *int,
orderBy *ent.UserOrder,
where *ent.UserWhereInput,
) (*ent.UserConnection, error) {
return r.Client.User.
Query().
Paginate(ctx, after, first, before, last,
ent.WithUserOrder(orderBy),
ent.WithUserFilter(where.Filter),
)
}
// Mutation resolver
type mutationResolver struct{ *Resolver }
func (r *Resolver) Mutation() generated.MutationResolver {
return &mutationResolver{r}
}
func (r *mutationResolver) CreateUser(ctx context.Context, input ent.CreateUserInput) (*ent.User, error) {
return r.Client.User.Create().SetInput(input).Save(ctx)
}
func (r *mutationResolver) UpdateUser(ctx context.Context, id int, input ent.UpdateUserInput) (*ent.User, error) {
return r.Client.User.UpdateOneID(id).SetInput(input).Save(ctx)
}
func (r *mutationResolver) DeleteUser(ctx context.Context, id int) (*ent.User, error) {
user, err := r.Client.User.Get(ctx, id)
if err != nil {
return nil, err
}
if err := r.Client.User.DeleteOne(user).Exec(ctx); err != nil {
return nil, err
}
return user, nil
}
11.4 启动GraphQL服务器
package main
import (
"log"
"net/http"
"github.com/99designs/gqlgen/graphql/handler"
"github.com/99designs/gqlgen/graphql/playground"
"myapp/ent"
"myapp/graph/generated"
"myapp/graph/resolver"
)
func main() {
client, err := ent.Open("postgres", "...")
if err != nil {
log.Fatal(err)
}
defer client.Close()
// 创建GraphQL服务器
srv := handler.NewDefaultServer(
generated.NewExecutableSchema(resolver.NewSchema(client)),
)
// 添加中间件
srv.Use(entgql.Transactioner{TxOpener: client})
http.Handle("/", playground.Handler("GraphQL playground", "/query"))
http.Handle("/query", srv)
log.Println("GraphQL server running at http://localhost:8080/")
log.Fatal(http.ListenAndServe(":8080", nil))
}
十二、gRPC集成
12.1 定义Proto文件
// proto/user.proto
syntax = "proto3";
package user;
option go_package = "myapp/proto/user";
service UserService {
rpc CreateUser(CreateUserRequest) returns (User);
rpc GetUser(GetUserRequest) returns (User);
rpc ListUsers(ListUsersRequest) returns (ListUsersResponse);
rpc UpdateUser(UpdateUserRequest) returns (User);
rpc DeleteUser(DeleteUserRequest) returns (DeleteUserResponse);
}
message User {
int64 id = 1;
string name = 2;
string email = 3;
string status = 4;
int64 created_at = 5;
}
message CreateUserRequest {
string name = 1;
string email = 2;
}
message GetUserRequest {
int64 id = 1;
}
message ListUsersRequest {
int32 page = 1;
int32 page_size = 2;
string status = 3;
}
message ListUsersResponse {
repeated User users = 1;
int32 total = 2;
int32 page = 3;
int32 pages = 4;
}
message UpdateUserRequest {
int64 id = 1;
optional string name = 2;
optional string email = 3;
optional string status = 4;
}
message DeleteUserRequest {
int64 id = 1;
}
message DeleteUserResponse {
bool success = 1;
}
12.2 实现gRPC服务
package grpc
import (
"context"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"myapp/ent"
"myapp/ent/user"
pb "myapp/proto/user"
)
type UserServiceServer struct {
pb.UnimplementedUserServiceServer
client *ent.Client
}
func NewUserServiceServer(client *ent.Client) *UserServiceServer {
return &UserServiceServer{client: client}
}
// 将Ent User转换为Proto User
func toProtoUser(u *ent.User) *pb.User {
return &pb.User{
Id: int64(u.ID),
Name: u.Name,
Email: u.Email,
Status: u.Status.String(),
CreatedAt: u.CreatedAt.Unix(),
}
}
// CreateUser 创建用户
func (s *UserServiceServer) CreateUser(ctx context.Context, req *pb.CreateUserRequest) (*pb.User, error) {
// 参数验证
if req.Name == "" {
return nil, status.Error(codes.InvalidArgument, "name is required")
}
if req.Email == "" {
return nil, status.Error(codes.InvalidArgument, "email is required")
}
// 创建用户
u, err := s.client.User.
Create().
SetName(req.Name).
SetEmail(req.Email).
SetStatus(user.StatusActive).
Save(ctx)
if err != nil {
if ent.IsConstraintError(err) {
return nil, status.Error(codes.AlreadyExists, "email already exists")
}
return nil, status.Error(codes.Internal, err.Error())
}
return toProtoUser(u), nil
}
// GetUser 获取用户
func (s *UserServiceServer) GetUser(ctx context.Context, req *pb.GetUserRequest) (*pb.User, error) {
u, err := s.client.User.Get(ctx, int(req.Id))
if err != nil {
if ent.IsNotFound(err) {
return nil, status.Error(codes.NotFound, "user not found")
}
return nil, status.Error(codes.Internal, err.Error())
}
return toProtoUser(u), nil
}
// ListUsers 列出用户
func (s *UserServiceServer) ListUsers(ctx context.Context, req *pb.ListUsersRequest) (*pb.ListUsersResponse, error) {
// 设置默认分页参数
page := int(req.Page)
if page < 1 {
page = 1
}
pageSize := int(req.PageSize)
if pageSize < 1 || pageSize > 100 {
pageSize = 10
}
// 构建查询
query := s.client.User.Query()
// 应用过滤条件
if req.Status != "" {
query = query.Where(user.StatusEQ(user.Status(req.Status)))
}
// 获取总数
total, err := query.Clone().Count(ctx)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
// 分页查询
users, err := query.
Order(ent.Desc(user.FieldCreatedAt)).
Offset((page - 1) * pageSize).
Limit(pageSize).
All(ctx)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
// 转换结果
pbUsers := make([]*pb.User, len(users))
for i, u := range users {
pbUsers[i] = toProtoUser(u)
}
return &pb.ListUsersResponse{
Users: pbUsers,
Total: int32(total),
Page: int32(page),
Pages: int32((total + pageSize - 1) / pageSize),
}, nil
}
// UpdateUser 更新用户
func (s *UserServiceServer) UpdateUser(ctx context.Context, req *pb.UpdateUserRequest) (*pb.User, error) {
// 检查用户是否存在
exists, err := s.client.User.Query().Where(user.IDEQ(int(req.Id))).Exist(ctx)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
if !exists {
return nil, status.Error(codes.NotFound, "user not found")
}
// 构建更新
update := s.client.User.UpdateOneID(int(req.Id))
if req.Name != nil {
update = update.SetName(*req.Name)
}
if req.Email != nil {
update = update.SetEmail(*req.Email)
}
if req.Status != nil {
update = update.SetStatus(user.Status(*req.Status))
}
u, err := update.Save(ctx)
if err != nil {
if ent.IsConstraintError(err) {
return nil, status.Error(codes.AlreadyExists, "email already exists")
}
return nil, status.Error(codes.Internal, err.Error())
}
return toProtoUser(u), nil
}
// DeleteUser 删除用户
func (s *UserServiceServer) DeleteUser(ctx context.Context, req *pb.DeleteUserRequest) (*pb.DeleteUserResponse, error) {
err := s.client.User.DeleteOneID(int(req.Id)).Exec(ctx)
if err != nil {
if ent.IsNotFound(err) {
return nil, status.Error(codes.NotFound, "user not found")
}
return nil, status.Error(codes.Internal, err.Error())
}
return &pb.DeleteUserResponse{Success: true}, nil
}
12.3 启动gRPC服务器
package main
import (
"log"
"net"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"
"myapp/ent"
grpcserver "myapp/grpc"
pb "myapp/proto/user"
)
func main() {
// 连接数据库
client, err := ent.Open("postgres", "...")
if err != nil {
log.Fatalf("failed opening connection to postgres: %v", err)
}
defer client.Close()
// 运行迁移
if err := client.Schema.Create(context.Background()); err != nil {
log.Fatalf("failed creating schema: %v", err)
}
// 创建gRPC服务器
server := grpc.NewServer(
grpc.UnaryInterceptor(loggingInterceptor),
)
// 注册服务
pb.RegisterUserServiceServer(server, grpcserver.NewUserServiceServer(client))
// 注册反射服务(用于grpcurl等工具)
reflection.Register(server)
// 启动服务器
lis, err := net.Listen("tcp", ":50051")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
log.Println("gRPC server running at :50051")
if err := server.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}
// 日志拦截器
func loggingInterceptor(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
start := time.Now()
resp, err := handler(ctx, req)
log.Printf(
"method=%s duration=%s error=%v",
info.FullMethod,
time.Since(start),
err,
)
return resp, err
}
十三、测试
13.1 单元测试
package ent_test
import (
"context"
"testing"
"myapp/ent"
"myapp/ent/enttest"
"myapp/ent/user"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestUserCRUD(t *testing.T) {
// 创建测试客户端(使用SQLite内存数据库)
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
ctx := context.Background()
// 测试创建
t.Run("Create", func(t *testing.T) {
u, err := client.User.
Create().
SetName("测试用户").
SetEmail("test@example.com").
SetAge(25).
Save(ctx)
require.NoError(t, err)
assert.NotZero(t, u.ID)
assert.Equal(t, "测试用户", u.Name)
assert.Equal(t, "test@example.com", u.Email)
})
// 测试查询
t.Run("Query", func(t *testing.T) {
u, err := client.User.
Query().
Where(user.EmailEQ("test@example.com")).
Only(ctx)
require.NoError(t, err)
assert.Equal(t, "测试用户", u.Name)
})
// 测试更新
t.Run("Update", func(t *testing.T) {
affected, err := client.User.
Update().
Where(user.EmailEQ("test@example.com")).
SetName("更新后的名字").
Save(ctx)
require.NoError(t, err)
assert.Equal(t, 1, affected)
// 验证更新
u, err := client.User.
Query().
Where(user.EmailEQ("test@example.com")).
Only(ctx)
require.NoError(t, err)
assert.Equal(t, "更新后的名字", u.Name)
})
// 测试删除
t.Run("Delete", func(t *testing.T) {
affected, err := client.User.
Delete().
Where(user.EmailEQ("test@example.com")).
Exec(ctx)
require.NoError(t, err)
assert.Equal(t, 1, affected)
// 验证删除
exists, err := client.User.
Query().
Where(user.EmailEQ("test@example.com")).
Exist(ctx)
require.NoError(t, err)
assert.False(t, exists)
})
}
13.2 关系测试
func TestUserPostRelation(t *testing.T) {
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
ctx := context.Background()
// 创建用户和文章
t.Run("CreateWithRelation", func(t *testing.T) {
// 创建用户
u, err := client.User.
Create().
SetName("作者").
SetEmail("author@example.com").
Save(ctx)
require.NoError(t, err)
// 创建文章
p1, err := client.Post.
Create().
SetTitle("文章1").
SetContent("内容1").
SetAuthor(u).
Save(ctx)
require.NoError(t, err)
p2, err := client.Post.
Create().
SetTitle("文章2").
SetContent("内容2").
SetAuthorID(u.ID).
Save(ctx)
require.NoError(t, err)
// 验证关系
posts, err := u.QueryPosts().All(ctx)
require.NoError(t, err)
assert.Len(t, posts, 2)
// 验证反向关系
author, err := p1.QueryAuthor().Only(ctx)
require.NoError(t, err)
assert.Equal(t, u.ID, author.ID)
})
// 测试预加载
t.Run("EagerLoading", func(t *testing.T) {
users, err := client.User.
Query().
WithPosts().
All(ctx)
require.NoError(t, err)
assert.NotEmpty(t, users)
for _, u := range users {
// Edges.Posts已经被预加载
assert.NotNil(t, u.Edges.Posts)
}
})
}
13.3 事务测试
func TestTransaction(t *testing.T) {
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
ctx := context.Background()
t.Run("CommitTransaction", func(t *testing.T) {
tx, err := client.Tx(ctx)
require.NoError(t, err)
_, err = tx.User.
Create().
SetName("事务用户").
SetEmail("tx@example.com").
Save(ctx)
require.NoError(t, err)
err = tx.Commit()
require.NoError(t, err)
// 验证提交成功
exists, err := client.User.
Query().
Where(user.EmailEQ("tx@example.com")).
Exist(ctx)
require.NoError(t, err)
assert.True(t, exists)
})
t.Run("RollbackTransaction", func(t *testing.T) {
tx, err := client.Tx(ctx)
require.NoError(t, err)
_, err = tx.User.
Create().
SetName("回滚用户").
SetEmail("rollback@example.com").
Save(ctx)
require.NoError(t, err)
err = tx.Rollback()
require.NoError(t, err)
// 验证回滚成功
exists, err := client.User.
Query().
Where(user.EmailEQ("rollback@example.com")).
Exist(ctx)
require.NoError(t, err)
assert.False(t, exists)
})
}
13.4 Mock测试
package service_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"myapp/ent"
)
// MockUserRepository 用户仓库Mock
type MockUserRepository struct {
mock.Mock
}
func (m *MockUserRepository) Create(ctx context.Context, name, email string) (*ent.User, error) {
args := m.Called(ctx, name, email)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*ent.User), args.Error(1)
}
func (m *MockUserRepository) GetByID(ctx context.Context, id int) (*ent.User, error) {
args := m.Called(ctx, id)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*ent.User), args.Error(1)
}
func (m *MockUserRepository) GetByEmail(ctx context.Context, email string) (*ent.User, error) {
args := m.Called(ctx, email)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*ent.User), args.Error(1)
}
// UserService 用户服务
type UserService struct {
repo UserRepository
}
type UserRepository interface {
Create(ctx context.Context, name, email string) (*ent.User, error)
GetByID(ctx context.Context, id int) (*ent.User, error)
GetByEmail(ctx context.Context, email string) (*ent.User, error)
}
func NewUserService(repo UserRepository) *UserService {
return &UserService{repo: repo}
}
func (s *UserService) Register(ctx context.Context, name, email string) (*ent.User, error) {
// 检查邮箱是否已存在
existing, err := s.repo.GetByEmail(ctx, email)
if err == nil && existing != nil {
return nil, fmt.Errorf("email already exists")
}
return s.repo.Create(ctx, name, email)
}
// 测试
func TestUserService_Register(t *testing.T) {
ctx := context.Background()
t.Run("Success", func(t *testing.T) {
mockRepo := new(MockUserRepository)
service := NewUserService(mockRepo)
expectedUser := &ent.User{
ID: 1,
Name: "测试用户",
Email: "test@example.com",
}
// 设置期望
mockRepo.On("GetByEmail", ctx, "test@example.com").Return(nil, fmt.Errorf("not found"))
mockRepo.On("Create", ctx, "测试用户", "test@example.com").Return(expectedUser, nil)
// 执行
user, err := service.Register(ctx, "测试用户", "test@example.com")
// 断言
assert.NoError(t, err)
assert.Equal(t, expectedUser, user)
mockRepo.AssertExpectations(t)
})
t.Run("EmailExists", func(t *testing.T) {
mockRepo := new(MockUserRepository)
service := NewUserService(mockRepo)
existingUser := &ent.User{
ID: 1,
Email: "existing@example.com",
}
mockRepo.On("GetByEmail", ctx, "existing@example.com").Return(existingUser, nil)
user, err := service.Register(ctx, "新用户", "existing@example.com")
assert.Error(t, err)
assert.Nil(t, user)
assert.Contains(t, err.Error(), "email already exists")
mockRepo.AssertExpectations(t)
})
}
13.5 Hooks测试
func TestHooks(t *testing.T) {
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
ctx := context.Background()
// 记录Hook被调用
var hookCalled bool
// 添加测试Hook
client.User.Use(func(next ent.Mutator) ent.Mutator {
return hook.UserFunc(func(ctx context.Context, m *ent.UserMutation) (ent.Value, error) {
hookCalled = true
return next.Mutate(ctx, m)
})
})
t.Run("HookIsCalled", func(t *testing.T) {
hookCalled = false
_, err := client.User.
Create().
SetName("Hook测试").
SetEmail("hook@example.com").
Save(ctx)
require.NoError(t, err)
assert.True(t, hookCalled, "Hook should be called")
})
}
十四、性能优化
14.1 批量操作优化
package optimization
import (
"context"
"myapp/ent"
)
// 批量创建(使用CreateBulk)
func BulkCreate(ctx context.Context, client *ent.Client, users []UserInput) error {
builders := make([]*ent.UserCreate, len(users))
for i, u := range users {
builders[i] = client.User.
Create().
SetName(u.Name).
SetEmail(u.Email)
}
_, err := client.User.CreateBulk(builders...).Save(ctx)
return err
}
// 批量更新
func BulkUpdate(ctx context.Context, client *ent.Client, ids []int, status string) error {
_, err := client.User.
Update().
Where(user.IDIn(ids...)).
SetStatus(status).
Save(ctx)
return err
}
// 批量删除
func BulkDelete(ctx context.Context, client *ent.Client, ids []int) error {
_, err := client.User.
Delete().
Where(user.IDIn(ids...)).
Exec(ctx)
return err
}
// 分批处理大量数据
func ProcessInBatches(ctx context.Context, client *ent.Client, batchSize int, fn func([]*ent.User) error) error {
var lastID int
for {
users, err := client.User.
Query().
Where(user.IDGT(lastID)).
Order(ent.Asc(user.FieldID)).
Limit(batchSize).
All(ctx)
if err != nil {
return err
}
if len(users) == 0 {
break
}
if err := fn(users); err != nil {
return err
}
lastID = users[len(users)-1].ID
}
return nil
}
14.2 查询优化
package optimization
import (
"context"
"myapp/ent"
"myapp/ent/user"
"myapp/ent/post"
)
// 只选择需要的字段
func SelectSpecificFields(ctx context.Context, client *ent.Client) ([]struct {
ID int
Name string
Email string
}, error) {
var results []struct {
ID int `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
}
err := client.User.
Query().
Select(user.FieldID, user.FieldName, user.FieldEmail).
Scan(ctx, &results)
return results, err
}
// 避免N+1问题 - 使用预加载
func EagerLoadingExample(ctx context.Context, client *ent.Client) ([]*ent.User, error) {
return client.User.
Query().
WithPosts(func(q *ent.PostQuery) {
q.Where(post.StatusEQ("published"))
q.Order(ent.Desc(post.FieldCreatedAt))
q.Limit(5)
q.WithTags() // 嵌套预加载
}).
WithProfile().
All(ctx)
}
// 使用索引优化查询
func QueryWithIndex(ctx context.Context, client *ent.Client, email string) (*ent.User, error) {
// 确保email字段有索引
return client.User.
Query().
Where(user.EmailEQ(email)).
Only(ctx)
}
// 只获取计数,不加载实体
func CountOnly(ctx context.Context, client *ent.Client) (int, error) {
return client.User.
Query().
Where(user.StatusEQ("active")).
Count(ctx)
}
// 只检查存在性
func ExistsOnly(ctx context.Context, client *ent.Client, email string) (bool, error) {
return client.User.
Query().
Where(user.EmailEQ(email)).
Exist(ctx)
}
// 使用Only代替First(当确定只有一个结果时)
func GetSingleResult(ctx context.Context, client *ent.Client, id int) (*ent.User, error) {
return client.User.
Query().
Where(user.IDEQ(id)).
Only(ctx) // 如果有多个结果会返回错误
}
14.3 连接池配置
package main
import (
"database/sql"
"time"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
"myapp/ent"
_ "github.com/lib/pq"
)
func NewClientWithPool() (*ent.Client, error) {
// 打开数据库连接
db, err := sql.Open("postgres", "postgres://localhost:5432/myapp?sslmode=disable")
if err != nil {
return nil, err
}
// 配置连接池
db.SetMaxOpenConns(100) // 最大打开连接数
db.SetMaxIdleConns(10) // 最大空闲连接数
db.SetConnMaxLifetime(time.Hour) // 连接最大存活时间
db.SetConnMaxIdleTime(time.Minute * 30) // 空闲连接最大存活时间
// 创建Ent驱动
drv := entsql.OpenDB(dialect.Postgres, db)
// 创建Ent客户端
client := ent.NewClient(ent.Driver(drv))
return client, nil
}
14.4 缓存策略
package cache
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/go-redis/redis/v8"
"myapp/ent"
)
type CachedUserRepository struct {
client *ent.Client
redis *redis.Client
ttl time.Duration
}
func NewCachedUserRepository(client *ent.Client, redis *redis.Client) *CachedUserRepository {
return &CachedUserRepository{
client: client,
redis: redis,
ttl: time.Minute * 15,
}
}
func (r *CachedUserRepository) cacheKey(id int) string {
return fmt.Sprintf("user:%d", id)
}
// GetByID 带缓存的获取
func (r *CachedUserRepository) GetByID(ctx context.Context, id int) (*ent.User, error) {
key := r.cacheKey(id)
// 尝试从缓存获取
cached, err := r.redis.Get(ctx, key).Result()
if err == nil {
var user ent.User
if err := json.Unmarshal([]byte(cached), &user); err == nil {
return &user, nil
}
}
// 从数据库获取
user, err := r.client.User.Get(ctx, id)
if err != nil {
return nil, err
}
// 写入缓存
data, _ := json.Marshal(user)
r.redis.Set(ctx, key, data, r.ttl)
return user, nil
}
// Update 更新并清除缓存
func (r *CachedUserRepository) Update(ctx context.Context, id int, name string) (*ent.User, error) {
user, err := r.client.User.
UpdateOneID(id).
SetName(name).
Save(ctx)
if err != nil {
return nil, err
}
// 清除缓存
r.redis.Del(ctx, r.cacheKey(id))
return user, nil
}
// Delete 删除并清除缓存
func (r *CachedUserRepository) Delete(ctx context.Context, id int) error {
err := r.client.User.DeleteOneID(id).Exec(ctx)
if err != nil {
return err
}
// 清除缓存
r.redis.Del(ctx, r.cacheKey(id))
return nil
}
// 使用缓存拦截器
func CacheInterceptor(redis *redis.Client, ttl time.Duration) ent.Interceptor {
return intercept.TraverseFunc(func(ctx context.Context, q intercept.Query) error {
// 可以在这里实现查询级别的缓存
return nil
})
}
十五、多数据库支持
15.1 读写分离
package database
import (
"context"
"database/sql"
"sync/atomic"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
"myapp/ent"
)
// ReadWriteDriver 读写分离驱动
type ReadWriteDriver struct {
write dialect.Driver
reads []dialect.Driver
counter uint64
}
func NewReadWriteDriver(writeDB *sql.DB, readDBs ...*sql.DB) *ReadWriteDriver {
reads := make([]dialect.Driver, len(readDBs))
for i, db := range readDBs {
reads[i] = entsql.OpenDB(dialect.Postgres, db)
}
return &ReadWriteDriver{
write: entsql.OpenDB(dialect.Postgres, writeDB),
reads: reads,
}
}
// 实现dialect.Driver接口
func (d *ReadWriteDriver) Dialect() string {
return d.write.Dialect()
}
func (d *ReadWriteDriver) Close() error {
for _, read := range d.reads {
read.Close()
}
return d.write.Close()
}
// Exec 写操作使用主库
func (d *ReadWriteDriver) Exec(ctx context.Context, query string, args, v interface{}) error {
return d.write.Exec(ctx, query, args, v)
}
// Query 读操作使用从库(轮询)
func (d *ReadWriteDriver) Query(ctx context.Context, query string, args, v interface{}) error {
// 检查是否强制使用主库
if useMaster(ctx) {
return d.write.Query(ctx, query, args, v)
}
// 轮询选择从库
idx := atomic.AddUint64(&d.counter, 1) % uint64(len(d.reads))
return d.reads[idx].Query(ctx, query, args, v)
}
// Tx 事务使用主库
func (d *ReadWriteDriver) Tx(ctx context.Context) (dialect.Tx, error) {
return d.write.Tx(ctx)
}
// BeginTx 开始事务
func (d *ReadWriteDriver) BeginTx(ctx context.Context, opts *sql.TxOptions) (dialect.Tx, error) {
return d.write.(interface {
BeginTx(context.Context, *sql.TxOptions) (dialect.Tx, error)
}).BeginTx(ctx, opts)
}
// Context key for forcing master
type masterKey struct{}
func useMaster(ctx context.Context) bool {
v, _ := ctx.Value(masterKey{}).(bool)
return v
}
// WithMaster 强制使用主库
func WithMaster(ctx context.Context) context.Context {
return context.WithValue(ctx, masterKey{}, true)
}
// 使用示例
func SetupReadWriteClient() (*ent.Client, error) {
// 主库
masterDB, err := sql.Open("postgres", "postgres://master:5432/myapp")
if err != nil {
return nil, err
}
// 从库
slave1DB, err := sql.Open("postgres", "postgres://slave1:5432/myapp")
if err != nil {
return nil, err
}
slave2DB, err := sql.Open("postgres", "postgres://slave2:5432/myapp")
if err != nil {
return nil, err
}
// 创建读写分离驱动
drv := NewReadWriteDriver(masterDB, slave1DB, slave2DB)
return ent.NewClient(ent.Driver(drv)), nil
}
// 业务代码使用
func QueryExample(ctx context.Context, client *ent.Client) {
// 普通查询 - 使用从库
users, _ := client.User.Query().All(ctx)
// 强制使用主库查询(例如刚写入后立即读取)
user, _ := client.User.Query().
Where(user.IDEQ(1)).
Only(WithMaster(ctx))
}
15.2 多租户数据库
package multitenancy
import (
"context"
"fmt"
"sync"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
"myapp/ent"
)
// TenantManager 租户管理器
type TenantManager struct {
clients map[string]*ent.Client
mu sync.RWMutex
config DatabaseConfig
}
type DatabaseConfig struct {
Host string
Port int
User string
Password string
}
func NewTenantManager(config DatabaseConfig) *TenantManager {
return &TenantManager{
clients: make(map[string]*ent.Client),
config: config,
}
}
// GetClient 获取租户的数据库客户端
func (m *TenantManager) GetClient(tenantID string) (*ent.Client, error) {
m.mu.RLock()
client, exists := m.clients[tenantID]
m.mu.RUnlock()
if exists {
return client, nil
}
// 创建新连接
m.mu.Lock()
defer m.mu.Unlock()
// 双重检查
if client, exists := m.clients[tenantID]; exists {
return client, nil
}
// 每个租户使用独立的数据库
dsn := fmt.Sprintf(
"postgres://%s:%s@%s:%d/tenant_%s?sslmode=disable",
m.config.User,
m.config.Password,
m.config.Host,
m.config.Port,
tenantID,
)
client, err := ent.Open("postgres", dsn)
if err != nil {
return nil, err
}
// 自动迁移
if err := client.Schema.Create(context.Background()); err != nil {
client.Close()
return nil, err
}
m.clients[tenantID] = client
return client, nil
}
// CloseAll 关闭所有连接
func (m *TenantManager) CloseAll() {
m.mu.Lock()
defer m.mu.Unlock()
for _, client := range m.clients {
client.Close()
}
}
// 中间件:从请求中获取租户并注入客户端
func TenantMiddleware(manager *TenantManager) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tenantID := r.Header.Get("X-Tenant-ID")
if tenantID == "" {
http.Error(w, "Tenant ID required", http.StatusBadRequest)
return
}
client, err := manager.GetClient(tenantID)
if err != nil {
http.Error(w, "Failed to get tenant database", http.StatusInternalServerError)
return
}
ctx := context.WithValue(r.Context(), "ent_client", client)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// 从context获取客户端
func ClientFromContext(ctx context.Context) *ent.Client {
return ctx.Value("ent_client").(*ent.Client)
}
15.3 Schema级别多租户
package schema
import (
"context"
"fmt"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"myapp/ent"
)
// SchemaDriver 支持Schema切换的驱动
type SchemaDriver struct {
dialect.Driver
schema string
}
func NewSchemaDriver(drv dialect.Driver, schema string) *SchemaDriver {
return &SchemaDriver{
Driver: drv,
schema: schema,
}
}
func (d *SchemaDriver) Exec(ctx context.Context, query string, args, v interface{}) error {
// 设置search_path
setSchema := fmt.Sprintf("SET search_path TO %s", d.schema)
if err := d.Driver.Exec(ctx, setSchema, []interface{}{}, nil); err != nil {
return err
}
return d.Driver.Exec(ctx, query, args, v)
}
func (d *SchemaDriver) Query(ctx context.Context, query string, args, v interface{}) error {
setSchema := fmt.Sprintf("SET search_path TO %s", d.schema)
if err := d.Driver.Exec(ctx, setSchema, []interface{}{}, nil); err != nil {
return err
}
return d.Driver.Query(ctx, query, args, v)
}
// 为租户创建Schema
func CreateTenantSchema(ctx context.Context, db *sql.DB, tenantID string) error {
schemaName := fmt.Sprintf("tenant_%s", tenantID)
_, err := db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schemaName))
return err
}
// 获取租户客户端
func GetTenantClient(baseClient *ent.Client, tenantID string) *ent.Client {
schemaName := fmt.Sprintf("tenant_%s", tenantID)
drv := NewSchemaDriver(baseClient.Driver(), schemaName)
return ent.NewClient(ent.Driver(drv))
}
十六、完整项目结构
myapp/
├── cmd/
│ └── server/
│ └── main.go # 应用入口
├── ent/
│ ├── schema/ # Schema定义
│ │ ├── user.go
│ │ ├── post.go
│ │ ├── comment.go
│ │ └── mixin/
│ │ ├── time.go
│ │ └── soft_delete.go
│ ├── generate.go # 代码生成配置
│ ├── entc.go # Ent配置
│ ├── client.go # 生成的客户端
│ ├── user.go # 生成的User实体
│ ├── post.go # 生成的Post实体
│ ├── migrate/
│ │ ├── main.go # 迁移脚本
│ │ └── migrations/ # 版本化迁移文件
│ ├── hook/ # 生成的Hooks
│ ├── intercept/ # 生成的拦截器
│ └── privacy/ # 生成的隐私规则
├── internal/
│ ├── config/
│ │ └── config.go # 配置管理
│ ├── database/
│ │ └── database.go # 数据库连接
│ ├── repository/ # 数据访问层
│ │ ├── user.go
│ │ └── post.go
│ ├── service/ # 业务逻辑层
│ │ ├── user.go
│ │ └── post.go
│ └── handler/ # HTTP处理器
│ ├── user.go
│ └── post.go
├── pkg/
│ ├── pagination/
│ │ └── pagination.go # 分页工具
│ └── validator/
│ └── validator.go # 验证工具
├── api/
│ └── proto/ # gRPC Proto文件
│ └── user.proto
├── graph/ # GraphQL
│ ├── resolver/
│ │ └── resolver.go
│ └── generated/
├── go.mod
├── go.sum
└── Makefile
16.1 Makefile示例
.PHONY: generate migrate test run
# 生成Ent代码
generate:
go generate ./ent
# 创建新的Schema
new-schema:
go run -mod=mod entgo.io/ent/cmd/ent new $(name)
# 生成迁移文件
migrate-diff:
go run -mod=mod ent/migrate/main.go $(name)
# 应用迁移
migrate-apply:
atlas migrate apply \
--dir "file://ent/migrate/migrations" \
--url "$(DATABASE_URL)"
# 运行测试
test:
go test -v ./...
# 运行测试(带覆盖率)
test-coverage:
go test -v -coverprofile=coverage.out ./...
go tool cover -html=coverage.out
# 运行应用
run:
go run cmd/server/main.go
# 格式化代码
fmt:
go fmt ./...
goimports -w .
# 代码检查
lint:
golangci-lint run
# 构建
build:
go build -o bin/server cmd/server/main.go
# 生成Proto
proto:
protoc --go_out=. --go-grpc_out=. api/proto/*.proto
16.2 完整的Main函数示例
// cmd/server/main.go
package main
import (
"context"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"myapp/ent"
"myapp/internal/config"
"myapp/internal/database"
"myapp/internal/handler"
"myapp/internal/repository"
"myapp/internal/service"
)
func main() {
// 加载配置
cfg, err := config.Load()
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// 连接数据库
client, err := database.NewClient(cfg.Database)
if err != nil {
log.Fatalf("Failed to connect database: %v", err)
}
defer client.Close()
// 运行迁移
if err := client.Schema.Create(context.Background()); err != nil {
log.Fatalf("Failed to run migrations: %v", err)
}
// 初始化层
userRepo := repository.NewUserRepository(client)
userService := service.NewUserService(userRepo)
userHandler := handler.NewUserHandler(userService)
// 设置路由
r := chi.NewRouter()
// 中间件
r.Use(middleware.Logger)
r.Use(middleware.Recoverer)
r.Use(middleware.Timeout(60 * time.Second))
// 路由
r.Route("/api/v1", func(r chi.Router) {
r.Route("/users", func(r chi.Router) {
r.Get("/", userHandler.List)
r.Post("/", userHandler.Create)
r.Get("/{id}", userHandler.Get)
r.Put("/{id}", userHandler.Update)
r.Delete("/{id}", userHandler.Delete)
})
})
// 健康检查
r.Get("/health", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})
// 创建服务器
server := &http.Server{
Addr: ":" + cfg.Server.Port,
Handler: r,
ReadTimeout: 15 * time.Second,
WriteTimeout: 15 * time.Second,
IdleTimeout: 60 * time.Second,
}
// 优雅关闭
go func() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
<-sigChan
log.Println("Shutting down server...")
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
log.Printf("Server shutdown error: %v", err)
}
client.Close()
log.Println("Server stopped")
}()
// 启动服务器
log.Printf("Server running at %s", server.Addr)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Server error: %v", err)
}
}
十七、最佳实践
17.1 Schema设计最佳实践
使用Mixin复用通用字段
// 创建基础Mixin
type BaseMixin struct {
mixin.Schema
}
func (BaseMixin) Fields() []ent.Field {
return []ent.Field{
field.Time("created_at").
Immutable().
Default(time.Now),
field.Time("updated_at").
Default(time.Now).
UpdateDefault(time.Now),
}
}
// 在所有实体中使用
type User struct {
ent.Schema
}
func (User) Mixin() []ent.Mixin {
return []ent.Mixin{BaseMixin{}}
}
合理使用索引
// 为经常查询的字段添加索引
func (User) Indexes() []ent.Index {
return []ent.Index{
index.Fields("email").Unique(),
index.Fields("status", "created_at"),
index.Fields("tenant_id", "email").Unique(),
}
}
字段验证
func (User) Fields() []ent.Field {
return []ent.Field{
// 使用内置验证
field.String("email").
NotEmpty().
Match(regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)),
// 自定义验证
field.Int("age").
Validate(func(age int) error {
if age < 0 || age > 150 {
return errors.New("age must be between 0 and 150")
}
return nil
}),
}
}
17.2 查询优化最佳实践
避免N+1问题
// ❌ 不好的做法
users, _ := client.User.Query().All(ctx)
for _, u := range users {
posts, _ := u.QueryPosts().All(ctx) // N+1查询
}
// ✅ 好的做法
users, _ := client.User.Query().
WithPosts().
All(ctx)
for _, u := range users {
posts := u.Edges.Posts // 已预加载
}
只查询需要的字段
// ❌ 不好的做法 - 查询所有字段
users, _ := client.User.Query().All(ctx)
// ✅ 好的做法 - 只查询需要的字段
var results []struct {
ID int
Name string
}
client.User.Query().
Select(user.FieldID, user.FieldName).
Scan(ctx, &results)
使用分页
// 始终对大量数据使用分页
users, _ := client.User.Query().
Offset((page - 1) * pageSize).
Limit(pageSize).
All(ctx)
17.3 事务使用最佳实践
保持事务简短
// ✅ 好的做法 - 事务只包含必要的数据库操作
err := WithTx(ctx, client, func(tx *ent.Tx) error {
user, err := tx.User.Create().SetName("用户").Save(ctx)
if err != nil {
return err
}
_, err = tx.Profile.Create().SetUserID(user.ID).Save(ctx)
return err
})
// ❌ 不好的做法 - 事务包含非数据库操作
err := WithTx(ctx, client, func(tx *ent.Tx) error {
user, err := tx.User.Create().SetName("用户").Save(ctx)
if err != nil {
return err
}
// 不应该在事务中执行外部API调用
sendEmail(user.Email) // ❌
return nil
})
17.4 错误处理最佳实践
// 使用Ent提供的错误检查函数
user, err := client.User.Get(ctx, id)
if err != nil {
if ent.IsNotFound(err) {
// 处理未找到错误
return nil, fmt.Errorf("user not found")
}
if ent.IsConstraintError(err) {
// 处理约束错误(如唯一索引冲突)
return nil, fmt.Errorf("user already exists")
}
// 其他错误
return nil, err
}
17.5 性能监控
// 添加查询性能监控Hook
client.Use(func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
start := time.Now()
v, err := next.Mutate(ctx, m)
duration := time.Since(start)
// 记录慢查询
if duration > 100*time.Millisecond {
log.Printf("Slow query detected: %s took %v", m.Type(), duration)
}
return v, err
})
})
十八、常见问题
18.1 如何处理软删除?
参见第 10.5 节的软删除实现。
18.2 如何实现乐观锁?
参见第 10.6 节的乐观锁实现。
18.3 如何实现多租户?
参见第 15.2 节的多租户数据库实现。
18.4 如何处理大量数据的迁移?
// 分批处理
func MigrateLargeData(ctx context.Context, client *ent.Client) error {
batchSize := 1000
var lastID int
for {
users, err := client.User.Query().
Where(user.IDGT(lastID)).
Order(ent.Asc(user.FieldID)).
Limit(batchSize).
All(ctx)
if err != nil {
return err
}
if len(users) == 0 {
break
}
// 处理批次数据
for _, u := range users {
// 执行迁移逻辑
}
lastID = users[len(users)-1].ID
}
return nil
}
18.5 如何调试生成的SQL?
// 启用SQL日志
import (
"entgo.io/ent/dialect/sql"
entsql "entgo.io/ent/dialect/sql"
)
// 创建带日志的驱动
drv := entsql.OpenDB(dialect.Postgres, db)
client := ent.NewClient(ent.Driver(entsql.Driver(drv)))
client.Debug() // 启用调试模式,会打印所有SQL查询
18.6 如何处理JSON字段?
type Metadata struct {
Tags []string `json:"tags"`
Settings map[string]string `json:"settings"`
}
// Schema定义
func (User) Fields() []ent.Field {
return []ent.Field{
field.JSON("metadata", &Metadata{}).
Optional(),
}
}
// 使用
user, _ := client.User.Create().
SetMetadata(&Metadata{
Tags: []string{"tag1", "tag2"},
Settings: map[string]string{"theme": "dark"},
}).
Save(ctx)
// 查询
metadata := user.Metadata
tags := metadata.Tags
十九、总结
Ent ORM 是一个功能强大、类型安全的 Go 语言 ORM 框架,具有以下特点:
核心优势
- 类型安全:100% 静态类型和显式 API
- 代码生成:通过代码生成避免运行时反射,性能更好
- 图遍历:轻松处理复杂的关系查询
- 扩展性强:支持 Hooks、Privacy、Interceptors 等多种扩展机制
- 现代化:支持 GraphQL、gRPC 等现代开发技术
适用场景
- 大型应用和微服务
- 需要复杂数据模型和关系的项目
- 对类型安全有严格要求的项目
- 需要代码优先(Code First)方法的项目
学习路径建议
- 从基础的 Schema 定义和 CRUD 操作开始
- 掌握关系定义和查询
- 学习 Hooks 和 Interceptors 进行扩展
- 了解高级特性如事务、Privacy、乐观锁等
- 根据项目需求集成 GraphQL 或 gRPC
- 学习性能优化和最佳实践
参考资源
希望本指南能帮助你快速掌握 Ent ORM 的使用!