164 lines
5.0 KiB
Go
Raw Normal View History

2024-07-30 23:39:55 +07:00
package transactions
import (
"context"
"furtuna-be/internal/common/mycontext"
"furtuna-be/internal/entity"
"go.uber.org/zap"
"gorm.io/gorm"
)
// TransactionRepository provides methods to perform CRUD operations on transactions.
// Every method issues its queries through the wrapped GORM handle.
type TransactionRepository struct {
	// db is the GORM handle used for all queries issued by this repository.
	db *gorm.DB
}
// NewTransactionRepository returns a TransactionRepository that runs all of
// its queries through the supplied GORM handle.
func NewTransactionRepository(db *gorm.DB) *TransactionRepository {
	repo := new(TransactionRepository)
	repo.db = db
	return repo
}
// Create inserts transaction into the database. On success it re-reads the
// stored row with FindByID (keyed by transaction.ID) and returns that copy.
// An insert failure is logged and returned; a failure of the follow-up read
// is returned as-is from FindByID.
func (r *TransactionRepository) Create(ctx context.Context, transaction *entity.Transaction) (*entity.Transaction, error) {
	result := r.db.WithContext(ctx).Create(transaction)
	if createErr := result.Error; createErr != nil {
		zap.L().Error("error when creating transaction", zap.Error(createErr))
		return nil, createErr
	}

	return r.FindByID(ctx, transaction.ID)
}
// Update persists transaction with GORM's Save and, on success, returns the
// same pointer it was given. Errors are logged and returned.
func (r *TransactionRepository) Update(ctx context.Context, transaction *entity.Transaction) (*entity.Transaction, error) {
	saveErr := r.db.WithContext(ctx).Save(transaction).Error
	if saveErr == nil {
		return transaction, nil
	}

	zap.L().Error("error when updating transaction", zap.Error(saveErr))
	return nil, saveErr
}
// FindByID loads the single transaction whose id column equals id using
// GORM's First. Any error, including GORM's record-not-found error, is
// logged and returned.
func (r *TransactionRepository) FindByID(ctx context.Context, id string) (*entity.Transaction, error) {
	found := new(entity.Transaction)

	lookupErr := r.db.WithContext(ctx).First(found, "id = ?", id).Error
	if lookupErr != nil {
		zap.L().Error("error when finding transaction by ID", zap.Error(lookupErr))
		return nil, lookupErr
	}

	return found, nil
}
// Delete removes the transaction whose id column equals id. A delete that
// matches no row is not reported as an error.
// NOTE(review): if entity.Transaction carries a gorm.DeletedAt field this is a
// soft delete — the model definition is not visible here; confirm.
func (r *TransactionRepository) Delete(ctx context.Context, id string) error {
	deleteErr := r.db.WithContext(ctx).Delete(&entity.Transaction{}, "id = ?", id).Error
	if deleteErr != nil {
		zap.L().Error("error when deleting transaction", zap.Error(deleteErr))
	}
	return deleteErr
}
// FindByPartnerID returns every transaction whose partner_id equals partnerID.
// No match is not an error: the returned slice is simply empty.
func (r *TransactionRepository) FindByPartnerID(ctx context.Context, partnerID int64) ([]entity.Transaction, error) {
	var rows []entity.Transaction

	result := r.db.WithContext(ctx).Where("partner_id = ?", partnerID).Find(&rows)
	if result.Error != nil {
		zap.L().Error("error when finding transactions by partner ID", zap.Error(result.Error))
		return nil, result.Error
	}

	return rows, nil
}
// FindByStatus returns every transaction whose status column equals status.
// No match is not an error: the returned slice is simply empty.
func (r *TransactionRepository) FindByStatus(ctx context.Context, status string) ([]entity.Transaction, error) {
	var matches []entity.Transaction

	filtered := r.db.WithContext(ctx).Where("status = ?", status)
	if findErr := filtered.Find(&matches).Error; findErr != nil {
		zap.L().Error("error when finding transactions by status", zap.Error(findErr))
		return nil, findErr
	}

	return matches, nil
}
// UpdateStatus sets the status of the transaction identified by id and returns
// the transaction with its Status field updated.
//
// The row is first loaded with FindByID; lookup errors (including GORM's
// record-not-found error) are returned unchanged. Only the status column is
// then written — GORM may also bump an auto-update timestamp if the model
// declares one (not visible here). Update failures are logged and returned.
func (r *TransactionRepository) UpdateStatus(ctx context.Context, id string, status string) (*entity.Transaction, error) {
	transaction, err := r.FindByID(ctx, id)
	if err != nil {
		return nil, err
	}

	// Set the field before the write so that any model hooks observe the new
	// status, and so the returned struct reflects it.
	transaction.Status = status

	// Write a single column instead of calling Save: Save rewrites every column
	// from the copy read above, which would silently revert changes other
	// writers made to unrelated fields between FindByID and this statement.
	if err := r.db.WithContext(ctx).Model(transaction).Update("status", status).Error; err != nil {
		zap.L().Error("error when updating transaction status", zap.Error(err))
		return nil, err
	}

	return transaction, nil
}
// ListTransactions returns one page of transactions ordered by created_at
// descending, together with the total number of rows matching the filters
// (counted before pagination is applied).
//
// An empty status or transactionType disables that filter. A negative offset
// is ignored, and a limit <= 0 means "no limit". Errors from either the count
// or the fetch are logged and returned.
func (r *TransactionRepository) ListTransactions(ctx context.Context, offset int, limit int, status string, transactionType string) ([]entity.Transaction, int64, error) {
	q := r.db.WithContext(ctx).Model(&entity.Transaction{}).Order("created_at DESC")

	// Each optional filter is applied only when its value is non-empty.
	optionalFilters := []struct {
		clause string
		value  string
	}{
		{clause: "status = ?", value: status},
		{clause: "transaction_type = ?", value: transactionType},
	}
	for _, f := range optionalFilters {
		if f.value != "" {
			q = q.Where(f.clause, f.value)
		}
	}

	var total int64
	if countErr := q.Count(&total).Error; countErr != nil {
		zap.L().Error("error when counting transactions", zap.Error(countErr))
		return nil, 0, countErr
	}

	if offset >= 0 {
		q = q.Offset(offset)
	}
	if limit > 0 {
		q = q.Limit(limit)
	}

	var page []entity.Transaction
	if findErr := q.Find(&page).Error; findErr != nil {
		zap.L().Error("error when listing transactions", zap.Error(findErr))
		return nil, 0, findErr
	}

	return page, total, nil
}
// GetTransactionList returns one page of transaction summaries for
// req.PartnerID, left-joined with site and partner names. It also returns the
// total number of matching rows, counted before pagination.
//
// The filters below are applied only when set: req.SiteID (non-nil), req.Type,
// req.Status, and req.Date (compared against DATE(t.created_at)). Offset and
// Limit are applied only when positive. Rows are ordered by t.created_at
// descending so that consecutive pages are stable and disjoint. Errors from
// the count or the fetch are logged and returned.
//
// NOTE(review): ctx is not attached to the query, so cancellation and
// deadlines are not propagated. If mycontext.Context satisfies
// context.Context, consider r.db.WithContext(ctx) as the other methods do —
// confirm before changing.
// NOTE(review): this reads table "transaction" directly while the entity-based
// methods use GORM's table name for entity.Transaction — confirm they are the
// same table.
func (r *TransactionRepository) GetTransactionList(ctx mycontext.Context, req entity.TransactionSearch) ([]*entity.TransactionList, int, error) {
	var transactions []*entity.TransactionList
	var total int64

	query := r.db.Table("transaction t").
		Select("t.id, t.transaction_type, t.status, t.created_at, s.name as site_name, p.name as partner_name, t.amount").
		Joins("left join sites s on t.site_id = s.id").
		Joins("left join partners p on t.partner_id = p.id").
		Where("t.partner_id = ?", req.PartnerID)

	if req.SiteID != nil {
		query = query.Where("t.site_id = ?", req.SiteID)
	}
	if req.Type != "" {
		query = query.Where("t.transaction_type = ?", req.Type)
	}
	if req.Status != "" {
		query = query.Where("t.status = ?", req.Status)
	}
	if req.Date != "" {
		query = query.Where("DATE(t.created_at) = ?", req.Date)
	}

	// Count before ordering/pagination, and check the error explicitly
	// instead of leaving it on the chain for Find to trip over.
	if err := query.Count(&total).Error; err != nil {
		zap.L().Error("error when counting transaction list", zap.Error(err))
		return nil, 0, err
	}

	// OFFSET/LIMIT without ORDER BY gives no guarantee that pages are stable
	// or non-overlapping; order newest first like ListTransactions does.
	query = query.Order("t.created_at DESC")
	if req.Offset > 0 {
		query = query.Offset(req.Offset)
	}
	if req.Limit > 0 {
		query = query.Limit(req.Limit)
	}

	if err := query.Find(&transactions).Error; err != nil {
		zap.L().Error("error when getting transaction list", zap.Error(err))
		return nil, 0, err
	}

	return transactions, int(total), nil
}