diff --git a/database/category_model.go b/database/category_model.go index fe20a57..25fdf05 100644 --- a/database/category_model.go +++ b/database/category_model.go @@ -8,6 +8,7 @@ type Category struct { ID string `gorm:"primaryKey" json:"id"` Name string `gorm:"not null" json:"name"` Code string `gorm:"not null;unique" json:"code"` + Sequence int `gorm:"default:null" json:"sequence"` CreatedAt time.Time `gorm:"default:CURRENT_TIMESTAMP" json:"created_at"` UpdatedAt time.Time `gorm:"default:CURRENT_TIMESTAMP" json:"updated_at"` DeletedAt time.Time `gorm:"default:null" json:"deleted_at"` diff --git a/go.mod b/go.mod index c0f0245..1db59bb 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( go.uber.org/fx v1.23.0 golang.org/x/crypto v0.34.0 gopkg.in/yaml.v3 v3.0.1 - gorm.io/driver/postgres v1.4.7 + gorm.io/driver/postgres v1.5.7 ) require ( @@ -20,7 +20,7 @@ require ( github.com/hashicorp/hcl v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect - github.com/jackc/pgx/v5 v5.2.0 // indirect + github.com/jackc/pgx/v5 v5.4.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect diff --git a/go.sum b/go.sum index 7cc3ae0..f576ccc 100644 --- a/go.sum +++ b/go.sum @@ -6,7 +6,6 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -41,16 +40,12 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.2.0 h1:NdPpngX0Y6z6XDFKqmFQaE+bCtkqzvQIOt1wvBlAqs8= -github.com/jackc/pgx/v5 v5.2.0/go.mod h1:Ptn7zmohNsWEsdxRawMzk3gaKma2obW+NWTnKa0S4nk= -github.com/jackc/puddle/v2 v2.1.2/go.mod h1:2lpufsF5mRHO6SuZkm0fNYxM6SWHfvyFj62KwNzgels= +github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= +github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= @@ -59,13 +54,8 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGw github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= @@ -81,7 +71,6 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/redis/go-redis/v9 v9.7.1 h1:4LhKRCIduqXqtvCUlaq9c8bdHOkICjDMrr1+Zb3osAc= github.com/redis/go-redis/v9 v9.7.1/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw= -github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= @@ -105,7 +94,6 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= @@ -114,8 +102,6 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/dig v1.18.0 h1:imUL1UiY0Mg4bqbFfsRQO5G4CGRBec/ZujWTvSVp3pw= go.uber.org/dig v1.18.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE= go.uber.org/fx v1.23.0 h1:lIr/gYWQGfTwGcSXWXu4vP5Ws6iqnNEIY+F/aFzCKTg= @@ -126,67 +112,30 @@ go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80= golang.org/x/crypto v0.34.0 h1:+/C6tk6rf/+t5DhUketUbD1aNGqiSX3j15Z6xuIDlBA= golang.org/x/crypto v0.34.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220923202941-7f9b1623fab7/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/postgres v1.4.7 h1:J06jXZCNq7Pdf7LIPn8tZn9LsWjd81BRSKveKNr0ZfA= -gorm.io/driver/postgres v1.4.7/go.mod h1:UJChCNLFKeBqQRE+HrkFUbKbq9idPXmTOk2u4Wok8S4= -gorm.io/gorm v1.24.2/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= +gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= +gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= diff --git a/vendor/github.com/jackc/pgx/v5/.gitignore b/vendor/github.com/jackc/pgx/v5/.gitignore index 348e014..a2ebbe9 100644 --- a/vendor/github.com/jackc/pgx/v5/.gitignore +++ b/vendor/github.com/jackc/pgx/v5/.gitignore @@ -23,3 +23,5 @@ _testmain.go .envrc /.testdb + +.DS_Store diff --git a/vendor/github.com/jackc/pgx/v5/CHANGELOG.md b/vendor/github.com/jackc/pgx/v5/CHANGELOG.md index 422c807..fb2304a 100644 --- a/vendor/github.com/jackc/pgx/v5/CHANGELOG.md +++ b/vendor/github.com/jackc/pgx/v5/CHANGELOG.md @@ -1,3 +1,83 @@ +# 5.4.3 (August 5, 2023) + +* Fix: QCharArrayOID was defined with the wrong OID (Christoph Engelbert) +* Fix: connect_timeout for sslmode=allow|prefer (smaher-edb) +* Fix: pgxpool: background health check cannot overflow pool +* Fix: Check for nil in defer when sending batch (recover properly from panic) +* Fix: json scan of non-string pointer to pointer +* Fix: zeronull.Timestamptz should use pgtype.Timestamptz +* Fix: NewConnsCount was not correctly counting connections created by Acquire directly. (James Hartig) +* RowTo(AddrOf)StructByPos ignores fields with "-" db tag +* Optimization: improve text format numeric parsing (horpto) + +# 5.4.2 (July 11, 2023) + +* Fix: RowScanner errors are fatal to Rows +* Fix: Enable failover efforts when pg_hba.conf disallows non-ssl connections (Brandon Kauffman) +* Hstore text codec internal improvements (Evan Jones) +* Fix: Stop timers for background reader when not in use. Fixes memory leak when closing connections (Adrian-Stefan Mares) +* Fix: Stop background reader as soon as possible. +* Add PgConn.SyncConn(). This combined with the above fix makes it safe to directly use the underlying net.Conn. + +# 5.4.1 (June 18, 2023) + +* Fix: concurrency bug with pgtypeDefaultMap and simple protocol (Lev Zakharov) +* Add TxOptions.BeginQuery to allow overriding the default BEGIN query + +# 5.4.0 (June 14, 2023) + +* Replace platform specific syscalls for non-blocking IO with more traditional goroutines and deadlines. This returns to the v4 approach with some additional improvements and fixes. This restores the ability to use a pgx.Conn over an ssh.Conn as well as other non-TCP or Unix socket connections. In addition, it is a significantly simpler implementation that is less likely to have cross platform issues. +* Optimization: The default type registrations are now shared among all connections. This saves about 100KB of memory per connection. `pgtype.Type` and `pgtype.Codec` values are now required to be immutable after registration. This was already necessary in most cases but wasn't documented until now. (Lev Zakharov) +* Fix: Ensure pgxpool.Pool.QueryRow.Scan releases connection on panic +* CancelRequest: don't try to read the reply (Nicola Murino) +* Fix: correctly handle bool type aliases (Wichert Akkerman) +* Fix: pgconn.CancelRequest: Fix unix sockets: don't use RemoteAddr() +* Fix: pgx.Conn memory leak with prepared statement caching (Evan Jones) +* Add BeforeClose to pgxpool.Pool (Evan Cordell) +* Fix: various hstore fixes and optimizations (Evan Jones) +* Fix: RowToStructByPos with embedded unexported struct +* Support different bool string representations (Lev Zakharov) +* Fix: error when using BatchResults.Exec on a select that returns an error after some rows. +* Fix: pipelineBatchResults.Exec() not returning error from ResultReader +* Fix: pipeline batch results not closing pipeline when error occurs while reading directly from results instead of using + a callback. +* Fix: scanning a table type into a struct +* Fix: scan array of record to pointer to slice of struct +* Fix: handle null for json (Cemre Mengu) +* Batch Query callback is called even when there is an error +* Add RowTo(AddrOf)StructByNameLax (Audi P. Risa P) + +# 5.3.1 (February 27, 2023) + +* Fix: Support v4 and v5 stdlib in same program (Tomáš Procházka) +* Fix: sql.Scanner not being used in certain cases +* Add text format jsonpath support +* Fix: fake non-blocking read adaptive wait time + +# 5.3.0 (February 11, 2023) + +* Fix: json values work with sql.Scanner +* Fixed / improved error messages (Mark Chambers and Yevgeny Pats) +* Fix: support scan into single dimensional arrays +* Fix: MaxConnLifetimeJitter setting actually jitter (Ben Weintraub) +* Fix: driver.Value representation of bytea should be []byte not string +* Fix: better handling of unregistered OIDs +* CopyFrom can use query cache to avoid extra round trip to get OIDs (Alejandro Do Nascimento Mora) +* Fix: encode to json ignoring driver.Valuer +* Support sql.Scanner on renamed base type +* Fix: pgtype.Numeric text encoding of negative numbers (Mark Chambers) +* Fix: connect with multiple hostnames when one can't be resolved +* Upgrade puddle to remove dependency on uber/atomic and fix alignment issue on 32-bit platform +* Fix: scanning json column into **string +* Multiple reductions in memory allocations +* Fake non-blocking read adapts its max wait time +* Improve CopyFrom performance and reduce memory usage +* Fix: encode []any to array +* Fix: LoadType for composite with dropped attributes (Felix Röhrich) +* Support v4 and v5 stdlib in same program +* Fix: text format array decoding with string of "NULL" +* Prefer binary format for arrays + # 5.2.0 (December 5, 2022) * `tracelog.TraceLog` implements the pgx.PrepareTracer interface. (Vitalii Solodilov) diff --git a/vendor/github.com/jackc/pgx/v5/README.md b/vendor/github.com/jackc/pgx/v5/README.md index c21a182..522206f 100644 --- a/vendor/github.com/jackc/pgx/v5/README.md +++ b/vendor/github.com/jackc/pgx/v5/README.md @@ -1,5 +1,5 @@ [![Go Reference](https://pkg.go.dev/badge/github.com/jackc/pgx/v5.svg)](https://pkg.go.dev/github.com/jackc/pgx/v5) -![Build Status](https://github.com/jackc/pgx/actions/workflows/ci.yml/badge.svg) +[![Build Status](https://github.com/jackc/pgx/actions/workflows/ci.yml/badge.svg)](https://github.com/jackc/pgx/actions/workflows/ci.yml) # pgx - PostgreSQL Driver and Toolkit @@ -88,7 +88,7 @@ See CONTRIBUTING.md for setup instructions. ## Supported Go and PostgreSQL Versions -pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.18 and higher and PostgreSQL 11 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). +pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.19 and higher and PostgreSQL 11 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). ## Version Policy @@ -132,13 +132,38 @@ These adapters can be used with the tracelog package. * [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus) * [github.com/jackc/pgx-zap](https://github.com/jackc/pgx-zap) * [github.com/jackc/pgx-zerolog](https://github.com/jackc/pgx-zerolog) +* [github.com/mcosta74/pgx-slog](https://github.com/mcosta74/pgx-slog) +* [github.com/kataras/pgx-golog](https://github.com/kataras/pgx-golog) ## 3rd Party Libraries with PGX Support +### [github.com/pashagolub/pgxmock](https://github.com/pashagolub/pgxmock) + +pgxmock is a mock library implementing pgx interfaces. +pgxmock has one and only purpose - to simulate pgx behavior in tests, without needing a real database connection. + ### [github.com/georgysavva/scany](https://github.com/georgysavva/scany) Library for scanning data from a database into Go structs and more. +### [github.com/vingarcia/ksql](https://github.com/vingarcia/ksql) + +A carefully designed SQL client for making using SQL easier, +more productive, and less error-prone on Golang. + ### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5) Adds GSSAPI / Kerberos authentication support. + +### [github.com/wcamarao/pmx](https://github.com/wcamarao/pmx) + +Explicit data mapping and scanning library for Go structs and slices. + +### [github.com/stephenafamo/scan](https://github.com/stephenafamo/scan) + +Type safe and flexible package for scanning database data into Go types. +Supports, structs, maps, slices and custom mapping functions. + +### [https://github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx) + +Code first migration library for native pgx (no database/sql abstraction). diff --git a/vendor/github.com/jackc/pgx/v5/batch.go b/vendor/github.com/jackc/pgx/v5/batch.go index af62039..8f6ea4f 100644 --- a/vendor/github.com/jackc/pgx/v5/batch.go +++ b/vendor/github.com/jackc/pgx/v5/batch.go @@ -21,13 +21,10 @@ type batchItemFunc func(br BatchResults) error // Query sets fn to be called when the response to qq is received. func (qq *QueuedQuery) Query(fn func(rows Rows) error) { qq.fn = func(br BatchResults) error { - rows, err := br.Query() - if err != nil { - return err - } + rows, _ := br.Query() defer rows.Close() - err = fn(rows) + err := fn(rows) if err != nil { return err } @@ -142,7 +139,10 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) { } commandTag, err := br.mrr.ResultReader().Close() - br.err = err + if err != nil { + br.err = err + br.mrr.Close() + } if br.conn.batchTracer != nil { br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ @@ -228,7 +228,7 @@ func (br *batchResults) Close() error { for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { if br.b.queuedQueries[br.qqIdx].fn != nil { err := br.b.queuedQueries[br.qqIdx].fn(br) - if err != nil && br.err == nil { + if err != nil { br.err = err } } else { @@ -290,7 +290,7 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) { results, err := br.pipeline.GetResults() if err != nil { br.err = err - return pgconn.CommandTag{}, err + return pgconn.CommandTag{}, br.err } var commandTag pgconn.CommandTag switch results := results.(type) { @@ -309,7 +309,7 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) { }) } - return commandTag, err + return commandTag, br.err } // Query reads the results from the next query in the batch as if the query has been sent with Query. @@ -384,24 +384,20 @@ func (br *pipelineBatchResults) Close() error { } }() - if br.err != nil { - return br.err - } - - if br.lastRows != nil && br.lastRows.err != nil { + if br.err == nil && br.lastRows != nil && br.lastRows.err != nil { br.err = br.lastRows.err return br.err } if br.closed { - return nil + return br.err } // Read and run fn for all remaining items for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { if br.b.queuedQueries[br.qqIdx].fn != nil { err := br.b.queuedQueries[br.qqIdx].fn(br) - if err != nil && br.err == nil { + if err != nil { br.err = err } } else { diff --git a/vendor/github.com/jackc/pgx/v5/conn.go b/vendor/github.com/jackc/pgx/v5/conn.go index 4c8b59d..7c7081b 100644 --- a/vendor/github.com/jackc/pgx/v5/conn.go +++ b/vendor/github.com/jackc/pgx/v5/conn.go @@ -178,7 +178,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con case "simple_protocol": defaultQueryExecMode = QueryExecModeSimpleProtocol default: - return nil, fmt.Errorf("invalid default_query_exec_mode: %v", err) + return nil, fmt.Errorf("invalid default_query_exec_mode: %s", s) } } @@ -194,20 +194,20 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con return connConfig, nil } -// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that pgconn.ParseConfig +// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that [pgconn.ParseConfig] // does. In addition, it accepts the following options: // -// default_query_exec_mode -// Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See -// QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement". +// - default_query_exec_mode. +// Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See +// QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement". // -// statement_cache_capacity -// The maximum size of the statement cache used when executing a query with "cache_statement" query exec mode. -// Default: 512. +// - statement_cache_capacity. +// The maximum size of the statement cache used when executing a query with "cache_statement" query exec mode. +// Default: 512. // -// description_cache_capacity -// The maximum size of the description cache used when executing a query with "cache_describe" query exec mode. -// Default: 512. +// - description_cache_capacity. +// The maximum size of the description cache used when executing a query with "cache_describe" query exec mode. +// Default: 512. func ParseConfig(connString string) (*ConnConfig, error) { return ParseConfigWithOptions(connString, ParseConfigOptions{}) } @@ -382,11 +382,9 @@ func quoteIdentifier(s string) string { return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` } -// Ping executes an empty sql statement against the *Conn -// If the sql returns without error, the database Ping is considered successful, otherwise, the error is returned. +// Ping delegates to the underlying *pgconn.PgConn.Ping. func (c *Conn) Ping(ctx context.Context) error { - _, err := c.Exec(ctx, ";") - return err + return c.pgConn.Ping(ctx) } // PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the @@ -509,7 +507,7 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []a mrr := c.pgConn.Exec(ctx, sql) for mrr.NextResult() { - commandTag, err = mrr.ResultReader().Close() + commandTag, _ = mrr.ResultReader().Close() } err = mrr.Close() return commandTag, err @@ -585,8 +583,10 @@ const ( QueryExecModeCacheDescribe // Get the statement description on every execution. This uses the extended protocol. Queries require two round trips - // to execute. It does not use prepared statements (allowing usage with most connection poolers) and is safe even - // when the the database schema is modified concurrently. + // to execute. It does not use named prepared statements. But it does use the unnamed prepared statement to get the + // statement description on the first round trip and then uses it to execute the query on the second round trip. This + // may cause problems with connection poolers that switch the underlying connection between round trips. It is safe + // even when the the database schema is modified concurrently. QueryExecModeDescribeExec // Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol @@ -648,6 +648,9 @@ type QueryRewriter interface { // returned Rows even if an error is returned. The error will be the available in rows.Err() after rows are closed. It // is allowed to ignore the error returned from Query and handle it in Rows. // +// It is possible for a call of FieldDescriptions on the returned Rows to return nil even if the Query call did not +// return an error. +// // It is possible for a query to return one or more rows before encountering an error. In most cases the rows should be // collected before processing rather than processed while receiving each row. This avoids the possibility of the // application processing rows from a query that the server rejected. The CollectRows function is useful here. @@ -721,43 +724,10 @@ optionLoop: sd, explicitPreparedStatement := c.preparedStatements[sql] if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec { if sd == nil { - switch mode { - case QueryExecModeCacheStatement: - if c.statementCache == nil { - err = errDisabledStatementCache - rows.fatal(err) - return rows, err - } - sd = c.statementCache.Get(sql) - if sd == nil { - sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql) - if err != nil { - rows.fatal(err) - return rows, err - } - c.statementCache.Put(sd) - } - case QueryExecModeCacheDescribe: - if c.descriptionCache == nil { - err = errDisabledDescriptionCache - rows.fatal(err) - return rows, err - } - sd = c.descriptionCache.Get(sql) - if sd == nil { - sd, err = c.Prepare(ctx, "", sql) - if err != nil { - rows.fatal(err) - return rows, err - } - c.descriptionCache.Put(sd) - } - case QueryExecModeDescribeExec: - sd, err = c.Prepare(ctx, "", sql) - if err != nil { - rows.fatal(err) - return rows, err - } + sd, err = c.getStatementDescription(ctx, mode, sql) + if err != nil { + rows.fatal(err) + return rows, err } } @@ -827,6 +797,48 @@ optionLoop: return rows, rows.err } +// getStatementDescription returns the statement description of the sql query +// according to the given mode. +// +// If the mode is one that doesn't require to know the param and result OIDs +// then nil is returned without error. +func (c *Conn) getStatementDescription( + ctx context.Context, + mode QueryExecMode, + sql string, +) (sd *pgconn.StatementDescription, err error) { + + switch mode { + case QueryExecModeCacheStatement: + if c.statementCache == nil { + return nil, errDisabledStatementCache + } + sd = c.statementCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql) + if err != nil { + return nil, err + } + c.statementCache.Put(sd) + } + case QueryExecModeCacheDescribe: + if c.descriptionCache == nil { + return nil, errDisabledDescriptionCache + } + sd = c.descriptionCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, "", sql) + if err != nil { + return nil, err + } + c.descriptionCache.Put(sd) + } + case QueryExecModeDescribeExec: + return c.Prepare(ctx, "", sql) + } + return sd, err +} + // QueryRow is a convenience wrapper over Query. Any error that occurs while // querying is deferred until calling Scan on the returned Row. That Row will // error with ErrNoRows if no rows are returned. @@ -966,7 +978,7 @@ func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchR func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { if c.statementCache == nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache} + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache, closed: true} } distinctNewQueries := []*pgconn.StatementDescription{} @@ -998,7 +1010,7 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { if c.descriptionCache == nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache} + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache, closed: true} } distinctNewQueries := []*pgconn.StatementDescription{} @@ -1052,7 +1064,7 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) { pipeline := c.pgConn.StartPipeline(context.Background()) defer func() { - if pbr.err != nil { + if pbr != nil && pbr.err != nil { pipeline.Close() } }() @@ -1065,18 +1077,18 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d err := pipeline.Sync() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } for _, sd := range distinctNewQueries { results, err := pipeline.GetResults() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } resultSD, ok := results.(*pgconn.StatementDescription) if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)} + return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true} } // Fill in the previously empty / pending statement descriptions. @@ -1086,12 +1098,12 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d results, err := pipeline.GetResults() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } _, ok := results.(*pgconn.PipelineSync) if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)} + return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true} } } @@ -1106,7 +1118,9 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d for _, bi := range b.queuedQueries { err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments) if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + // we wrap the error so we the user can understand which query failed inside the batch + err = fmt.Errorf("error building query %s: %w", bi.query, err) + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } if bi.sd.Name == "" { @@ -1118,7 +1132,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d err := pipeline.Sync() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } return &pipelineBatchResults{ @@ -1272,6 +1286,8 @@ func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.Com rows, _ := c.Query(ctx, `select attname, atttypid from pg_attribute where attrelid=$1 + and not attisdropped + and attnum > 0 order by attnum`, typrelid, ) @@ -1313,6 +1329,7 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error for _, sd := range invalidatedStatements { pipeline.SendDeallocate(sd.Name) + delete(c.preparedStatements, sd.Name) } err := pipeline.Sync() diff --git a/vendor/github.com/jackc/pgx/v5/copy_from.go b/vendor/github.com/jackc/pgx/v5/copy_from.go index 41acafd..a2c227f 100644 --- a/vendor/github.com/jackc/pgx/v5/copy_from.go +++ b/vendor/github.com/jackc/pgx/v5/copy_from.go @@ -85,6 +85,7 @@ type copyFrom struct { columnNames []string rowSrc CopyFromSource readerErrChan chan error + mode QueryExecMode } func (ct *copyFrom) run(ctx context.Context) (int64, error) { @@ -105,9 +106,29 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { } quotedColumnNames := cbuf.String() - sd, err := ct.conn.Prepare(ctx, "", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) - if err != nil { - return 0, err + var sd *pgconn.StatementDescription + switch ct.mode { + case QueryExecModeExec, QueryExecModeSimpleProtocol: + // These modes don't support the binary format. Before the inclusion of the + // QueryExecModes, Conn.Prepare was called on every COPY operation to get + // the OIDs. These prepared statements were not cached. + // + // Since that's the same behavior provided by QueryExecModeDescribeExec, + // we'll default to that mode. + ct.mode = QueryExecModeDescribeExec + fallthrough + case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec: + var err error + sd, err = ct.conn.getStatementDescription( + ctx, + ct.mode, + fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName), + ) + if err != nil { + return 0, fmt.Errorf("statement description failed: %w", err) + } + default: + return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode) } r, w := io.Pipe() @@ -167,8 +188,13 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { } func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) { + const sendBufSize = 65536 - 5 // The packet has a 5-byte header + lastBufLen := 0 + largestRowLen := 0 for ct.rowSrc.Next() { + lastBufLen = len(buf) + values, err := ct.rowSrc.Values() if err != nil { return false, nil, err @@ -185,7 +211,15 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b } } - if len(buf) > 65536 { + rowLen := len(buf) - lastBufLen + if rowLen > largestRowLen { + largestRowLen = rowLen + } + + // Try not to overflow size of the buffer PgConn.CopyFrom will be reading into. If that happens then the nature of + // io.Pipe means that the next Read will be short. This can lead to pathological send sizes such as 65531, 13, 65531 + // 13, 65531, 13, 65531, 13. + if len(buf) > sendBufSize-largestRowLen { return true, buf, nil } } @@ -208,6 +242,7 @@ func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames [ columnNames: columnNames, rowSrc: rowSrc, readerErrChan: make(chan error), + mode: c.config.DefaultQueryExecMode, } return ct.run(ctx) diff --git a/vendor/github.com/jackc/pgx/v5/doc.go b/vendor/github.com/jackc/pgx/v5/doc.go index fb27187..7486f42 100644 --- a/vendor/github.com/jackc/pgx/v5/doc.go +++ b/vendor/github.com/jackc/pgx/v5/doc.go @@ -7,17 +7,17 @@ details. Establishing a Connection -The primary way of establishing a connection is with `pgx.Connect`. +The primary way of establishing a connection is with [pgx.Connect]: conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified -here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with -`ConnectConfig` to configure settings such as tracing that cannot be configured with a connection string. +here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the connection with +[ConnectConfig] to configure settings such as tracing that cannot be configured with a connection string. Connection Pool -`*pgx.Conn` represents a single connection to the database and is not concurrency safe. Use package +[*pgx.Conn] represents a single connection to the database and is not concurrency safe. Use package github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool. Query Interface @@ -69,8 +69,9 @@ Use Exec to execute a query that does not return a result set. PostgreSQL Data Types -The package pgtype provides extensive and customizable support for converting Go values to and from PostgreSQL values -including array and composite types. See that package's documentation for details. +pgx uses the pgtype package to converting Go values to and from PostgreSQL values. It supports many PostgreSQL types +directly and is customizable and extendable. User defined data types such as enums, domains, and composite types may +require type registration. See that package's documentation for details. Transactions diff --git a/vendor/github.com/jackc/pgx/v5/extended_query_builder.go b/vendor/github.com/jackc/pgx/v5/extended_query_builder.go index b0c0e02..0bbdfbb 100644 --- a/vendor/github.com/jackc/pgx/v5/extended_query_builder.go +++ b/vendor/github.com/jackc/pgx/v5/extended_query_builder.go @@ -1,6 +1,7 @@ package pgx import ( + "database/sql/driver" "fmt" "github.com/jackc/pgx/v5/internal/anynil" @@ -181,6 +182,19 @@ func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, } } } + if !ok { + var dv driver.Valuer + if dv, ok = arg.(driver.Valuer); ok { + v, err := dv.Value() + if err != nil { + return err + } + dt, ok = m.TypeForValue(v) + if ok { + arg = v + } + } + } if !ok { var str fmt.Stringer if str, ok = arg.(fmt.Stringer); ok { diff --git a/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go b/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go index 9e55c43..89e0c22 100644 --- a/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go +++ b/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go @@ -1,4 +1,7 @@ // Package iobufpool implements a global segregated-fit pool of buffers for IO. +// +// It uses *[]byte instead of []byte to avoid the sync.Pool allocation with Put. Unfortunately, using a pointer to avoid +// an allocation is purposely not documented. https://github.com/golang/go/issues/16323 package iobufpool import "sync" @@ -10,17 +13,27 @@ var pools [18]*sync.Pool func init() { for i := range pools { bufLen := 1 << (minPoolExpOf2 + i) - pools[i] = &sync.Pool{New: func() any { return make([]byte, bufLen) }} + pools[i] = &sync.Pool{ + New: func() any { + buf := make([]byte, bufLen) + return &buf + }, + } } } // Get gets a []byte of len size with cap <= size*2. -func Get(size int) []byte { +func Get(size int) *[]byte { i := getPoolIdx(size) if i >= len(pools) { - return make([]byte, size) + buf := make([]byte, size) + return &buf } - return pools[i].Get().([]byte)[:size] + + ptrBuf := (pools[i].Get().(*[]byte)) + *ptrBuf = (*ptrBuf)[:size] + + return ptrBuf } func getPoolIdx(size int) int { @@ -36,8 +49,8 @@ func getPoolIdx(size int) int { } // Put returns buf to the pool. -func Put(buf []byte) { - i := putPoolIdx(cap(buf)) +func Put(buf *[]byte) { + i := putPoolIdx(cap(*buf)) if i < 0 { return } diff --git a/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go b/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go deleted file mode 100644 index 138a9aa..0000000 --- a/vendor/github.com/jackc/pgx/v5/internal/nbconn/bufferqueue.go +++ /dev/null @@ -1,70 +0,0 @@ -package nbconn - -import ( - "sync" -) - -const minBufferQueueLen = 8 - -type bufferQueue struct { - lock sync.Mutex - queue [][]byte - r, w int -} - -func (bq *bufferQueue) pushBack(buf []byte) { - bq.lock.Lock() - defer bq.lock.Unlock() - - if bq.w >= len(bq.queue) { - bq.growQueue() - } - bq.queue[bq.w] = buf - bq.w++ -} - -func (bq *bufferQueue) pushFront(buf []byte) { - bq.lock.Lock() - defer bq.lock.Unlock() - - if bq.w >= len(bq.queue) { - bq.growQueue() - } - copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w]) - bq.queue[bq.r] = buf - bq.w++ -} - -func (bq *bufferQueue) popFront() []byte { - bq.lock.Lock() - defer bq.lock.Unlock() - - if bq.r == bq.w { - return nil - } - - buf := bq.queue[bq.r] - bq.queue[bq.r] = nil // Clear reference so it can be garbage collected. - bq.r++ - - if bq.r == bq.w { - bq.r = 0 - bq.w = 0 - if len(bq.queue) > minBufferQueueLen { - bq.queue = make([][]byte, minBufferQueueLen) - } - } - - return buf -} - -func (bq *bufferQueue) growQueue() { - desiredLen := (len(bq.queue) + 1) * 3 / 2 - if desiredLen < minBufferQueueLen { - desiredLen = minBufferQueueLen - } - - newQueue := make([][]byte, desiredLen) - copy(newQueue, bq.queue) - bq.queue = newQueue -} diff --git a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go b/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go deleted file mode 100644 index 9968851..0000000 --- a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn.go +++ /dev/null @@ -1,478 +0,0 @@ -// Package nbconn implements a non-blocking net.Conn wrapper. -// -// It is designed to solve three problems. -// -// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all -// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion. -// -// The second is the inability to use a write deadline with a TLS.Conn without killing the connection. -// -// The third is to efficiently check if a connection has been closed via a non-blocking read. -package nbconn - -import ( - "crypto/tls" - "errors" - "net" - "os" - "sync" - "sync/atomic" - "syscall" - "time" - - "github.com/jackc/pgx/v5/internal/iobufpool" -) - -var errClosed = errors.New("closed") -var ErrWouldBlock = new(wouldBlockError) - -const fakeNonblockingWaitDuration = 100 * time.Millisecond - -// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read -// mode. -var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC) - -// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to -// ignore all future calls. -var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC) - -// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error. -type wouldBlockError struct{} - -func (*wouldBlockError) Error() string { - return "would block" -} - -func (*wouldBlockError) Timeout() bool { return true } -func (*wouldBlockError) Temporary() bool { return true } - -// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to -// the underlying connection. -type Conn interface { - net.Conn - - // Flush flushes any buffered writes. - Flush() error - - // BufferReadUntilBlock reads and buffers any sucessfully read bytes until the read would block. - BufferReadUntilBlock() error -} - -// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn. -type NetConn struct { - // 64 bit fields accessed with atomics must be at beginning of struct to guarantee alignment for certain 32-bit - // architectures. See BUGS section of https://pkg.go.dev/sync/atomic and https://github.com/jackc/pgx/issues/1288 and - // https://github.com/jackc/pgx/issues/1307. Only access with atomics - closed int64 // 0 = not closed, 1 = closed - - conn net.Conn - rawConn syscall.RawConn - - readQueue bufferQueue - writeQueue bufferQueue - - readFlushLock sync.Mutex - // non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the - // callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations. - nonblockWriteBuf []byte - nonblockWriteErr error - nonblockWriteN int - - readDeadlineLock sync.Mutex - readDeadline time.Time - readNonblocking bool - - writeDeadlineLock sync.Mutex - writeDeadline time.Time -} - -func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn { - nc := &NetConn{ - conn: conn, - } - - if !fakeNonBlockingIO { - if sc, ok := conn.(syscall.Conn); ok { - if rawConn, err := sc.SyscallConn(); err == nil { - nc.rawConn = rawConn - } - } - } - - return nc -} - -// Read implements io.Reader. -func (c *NetConn) Read(b []byte) (n int, err error) { - if c.isClosed() { - return 0, errClosed - } - - c.readFlushLock.Lock() - defer c.readFlushLock.Unlock() - - err = c.flush() - if err != nil { - return 0, err - } - - for n < len(b) { - buf := c.readQueue.popFront() - if buf == nil { - break - } - copiedN := copy(b[n:], buf) - if copiedN < len(buf) { - buf = buf[copiedN:] - c.readQueue.pushFront(buf) - } else { - iobufpool.Put(buf) - } - n += copiedN - } - - // If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to - // Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block. - if n > 0 { - return n, nil - } - - var readNonblocking bool - c.readDeadlineLock.Lock() - readNonblocking = c.readNonblocking - c.readDeadlineLock.Unlock() - - var readN int - if readNonblocking { - readN, err = c.nonblockingRead(b[n:]) - } else { - readN, err = c.conn.Read(b[n:]) - } - n += readN - return n, err -} - -// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is -// closed. Call Flush to actually write to the underlying connection. -func (c *NetConn) Write(b []byte) (n int, err error) { - if c.isClosed() { - return 0, errClosed - } - - buf := iobufpool.Get(len(b)) - copy(buf, b) - c.writeQueue.pushBack(buf) - return len(b), nil -} - -func (c *NetConn) Close() (err error) { - swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1) - if !swapped { - return errClosed - } - - defer func() { - closeErr := c.conn.Close() - if err == nil { - err = closeErr - } - }() - - c.readFlushLock.Lock() - defer c.readFlushLock.Unlock() - err = c.flush() - if err != nil { - return err - } - - return nil -} - -func (c *NetConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *NetConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t). -func (c *NetConn) SetDeadline(t time.Time) error { - err := c.SetReadDeadline(t) - if err != nil { - return err - } - return c.SetWriteDeadline(t) -} - -// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking. -func (c *NetConn) SetReadDeadline(t time.Time) error { - if c.isClosed() { - return errClosed - } - - c.readDeadlineLock.Lock() - defer c.readDeadlineLock.Unlock() - if c.readDeadline == disableSetDeadlineDeadline { - return nil - } - if t == disableSetDeadlineDeadline { - c.readDeadline = t - return nil - } - - if t == NonBlockingDeadline { - c.readNonblocking = true - t = time.Time{} - } else { - c.readNonblocking = false - } - - c.readDeadline = t - - return c.conn.SetReadDeadline(t) -} - -func (c *NetConn) SetWriteDeadline(t time.Time) error { - if c.isClosed() { - return errClosed - } - - c.writeDeadlineLock.Lock() - defer c.writeDeadlineLock.Unlock() - if c.writeDeadline == disableSetDeadlineDeadline { - return nil - } - if t == disableSetDeadlineDeadline { - c.writeDeadline = t - return nil - } - - c.writeDeadline = t - - return c.conn.SetWriteDeadline(t) -} - -func (c *NetConn) Flush() error { - if c.isClosed() { - return errClosed - } - - c.readFlushLock.Lock() - defer c.readFlushLock.Unlock() - return c.flush() -} - -// flush does the actual work of flushing the writeQueue. readFlushLock must already be held. -func (c *NetConn) flush() error { - var stopChan chan struct{} - var errChan chan error - - defer func() { - if stopChan != nil { - select { - case stopChan <- struct{}{}: - case <-errChan: - } - } - }() - - for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() { - remainingBuf := buf - for len(remainingBuf) > 0 { - n, err := c.nonblockingWrite(remainingBuf) - remainingBuf = remainingBuf[n:] - if err != nil { - if !errors.Is(err, ErrWouldBlock) { - buf = buf[:len(remainingBuf)] - copy(buf, remainingBuf) - c.writeQueue.pushFront(buf) - return err - } - - // Writing was blocked. Reading might unblock it. - if stopChan == nil { - stopChan, errChan = c.bufferNonblockingRead() - } - - select { - case err := <-errChan: - stopChan = nil - return err - default: - } - - } - } - iobufpool.Put(buf) - } - - return nil -} - -func (c *NetConn) BufferReadUntilBlock() error { - for { - buf := iobufpool.Get(8 * 1024) - n, err := c.nonblockingRead(buf) - if n > 0 { - buf = buf[:n] - c.readQueue.pushBack(buf) - } - - if err != nil { - if errors.Is(err, ErrWouldBlock) { - return nil - } else { - return err - } - } - } -} - -func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) { - stopChan = make(chan struct{}) - errChan = make(chan error, 1) - - go func() { - for { - err := c.BufferReadUntilBlock() - if err != nil { - errChan <- err - return - } - - select { - case <-stopChan: - return - default: - } - } - }() - - return stopChan, errChan -} - -func (c *NetConn) isClosed() bool { - closed := atomic.LoadInt64(&c.closed) - return closed == 1 -} - -func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) { - if c.rawConn == nil { - return c.fakeNonblockingWrite(b) - } else { - return c.realNonblockingWrite(b) - } -} - -func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) { - c.writeDeadlineLock.Lock() - defer c.writeDeadlineLock.Unlock() - - deadline := time.Now().Add(fakeNonblockingWaitDuration) - if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) { - err = c.conn.SetWriteDeadline(deadline) - if err != nil { - return 0, err - } - defer func() { - // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. - c.conn.SetWriteDeadline(c.writeDeadline) - - if err != nil { - if errors.Is(err, os.ErrDeadlineExceeded) { - err = ErrWouldBlock - } - } - }() - } - - return c.conn.Write(b) -} - -func (c *NetConn) nonblockingRead(b []byte) (n int, err error) { - if c.rawConn == nil { - return c.fakeNonblockingRead(b) - } else { - return c.realNonblockingRead(b) - } -} - -func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) { - c.readDeadlineLock.Lock() - defer c.readDeadlineLock.Unlock() - - deadline := time.Now().Add(fakeNonblockingWaitDuration) - if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) { - err = c.conn.SetReadDeadline(deadline) - if err != nil { - return 0, err - } - defer func() { - // Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails. - c.conn.SetReadDeadline(c.readDeadline) - - if err != nil { - if errors.Is(err, os.ErrDeadlineExceeded) { - err = ErrWouldBlock - } - } - }() - } - - return c.conn.Read(b) -} - -// syscall.Conn is interface - -// TLSClient establishes a TLS connection as a client over conn using config. -// -// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby -// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the -// *TLSConn is returned. -func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) { - tc := tls.Client(conn, config) - err := tc.Handshake() - if err != nil { - return nil, err - } - - // Ensure last written part of Handshake is actually sent. - err = conn.Flush() - if err != nil { - return nil, err - } - - return &TLSConn{ - tlsConn: tc, - nbConn: conn, - }, nil -} - -// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a -// tls.Conn. -type TLSConn struct { - tlsConn *tls.Conn - nbConn *NetConn -} - -func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) } -func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) } -func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() } -func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() } -func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() } -func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() } - -func (tc *TLSConn) Close() error { - // tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then - // sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our - // own 5 second deadline then make all set deadlines no-op. - tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5)) - tc.tlsConn.SetDeadline(disableSetDeadlineDeadline) - - return tc.tlsConn.Close() -} - -func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) } -func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) } -func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) } diff --git a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go b/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go deleted file mode 100644 index cf05df1..0000000 --- a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_fake_non_block.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build !(aix || android || darwin || dragonfly || freebsd || hurd || illumos || ios || linux || netbsd || openbsd || solaris) - -package nbconn - -// Not using unix build tag for support on Go 1.18. - -func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { - return c.fakeNonblockingWrite(b) -} - -func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { - return c.fakeNonblockingRead(b) -} diff --git a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go b/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go deleted file mode 100644 index ee48d12..0000000 --- a/vendor/github.com/jackc/pgx/v5/internal/nbconn/nbconn_real_non_block.go +++ /dev/null @@ -1,70 +0,0 @@ -//go:build aix || android || darwin || dragonfly || freebsd || hurd || illumos || ios || linux || netbsd || openbsd || solaris - -package nbconn - -// Not using unix build tag for support on Go 1.18. - -import ( - "errors" - "io" - "syscall" -) - -// realNonblockingWrite does a non-blocking write. readFlushLock must already be held. -func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) { - c.nonblockWriteBuf = b - c.nonblockWriteN = 0 - c.nonblockWriteErr = nil - err = c.rawConn.Write(func(fd uintptr) (done bool) { - c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf) - return true - }) - n = c.nonblockWriteN - if err == nil && c.nonblockWriteErr != nil { - if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) { - err = ErrWouldBlock - } else { - err = c.nonblockWriteErr - } - } - if err != nil { - // n may be -1 when an error occurs. - if n < 0 { - n = 0 - } - - return n, err - } - - return n, nil -} - -func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { - var funcErr error - err = c.rawConn.Read(func(fd uintptr) (done bool) { - n, funcErr = syscall.Read(int(fd), b) - return true - }) - if err == nil && funcErr != nil { - if errors.Is(funcErr, syscall.EWOULDBLOCK) { - err = ErrWouldBlock - } else { - err = funcErr - } - } - if err != nil { - // n may be -1 when an error occurs. - if n < 0 { - n = 0 - } - - return n, err - } - - // syscall read did not return an error and 0 bytes were read means EOF. - if n == 0 { - return 0, io.EOF - } - - return n, nil -} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go b/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go index 6ca9e33..8c4b2de 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go @@ -42,7 +42,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { Data: sc.clientFirstMessage(), } c.frontend.Send(saslInitialResponse) - err = c.frontend.Flush() + err = c.flushWithPotentialWriteReadDeadlock() if err != nil { return err } @@ -62,7 +62,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { Data: []byte(sc.clientFinalMessage()), } c.frontend.Send(saslResponse) - err = c.frontend.Flush() + err = c.flushWithPotentialWriteReadDeadlock() if err != nil { return err } diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/config.go b/vendor/github.com/jackc/pgx/v5/pgconn/config.go index 6282f41..1c2c647 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/config.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/config.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "math" "net" "net/url" @@ -27,7 +26,7 @@ type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error type GetSSLPasswordFunc func(ctx context.Context) string -// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A +// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A // manually initialized Config will cause ConnectConfig to panic. type Config struct { Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) @@ -211,9 +210,9 @@ func NetworkAddress(host string, port uint16) (network, address string) { // // In addition, ParseConfig accepts the following options: // -// servicefile -// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a -// part of the connection string. +// - servicefile. +// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a +// part of the connection string. func ParseConfig(connString string) (*Config, error) { var parseConfigOptions ParseConfigOptions return ParseConfigWithOptions(connString, parseConfigOptions) @@ -687,7 +686,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P caCertPool := x509.NewCertPool() caPath := sslrootcert - caCert, err := ioutil.ReadFile(caPath) + caCert, err := os.ReadFile(caPath) if err != nil { return nil, fmt.Errorf("unable to read CA file: %w", err) } @@ -705,7 +704,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P } if sslcert != "" && sslkey != "" { - buf, err := ioutil.ReadFile(sslkey) + buf, err := os.ReadFile(sslkey) if err != nil { return nil, fmt.Errorf("unable to read sslkey: %w", err) } @@ -744,7 +743,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P } else { pemKey = pem.EncodeToMemory(block) } - certfile, err := ioutil.ReadFile(sslcert) + certfile, err := os.ReadFile(sslcert) if err != nil { return nil, fmt.Errorf("unable to read cert: %w", err) } diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/internal/bgreader/bgreader.go b/vendor/github.com/jackc/pgx/v5/pgconn/internal/bgreader/bgreader.go new file mode 100644 index 0000000..e65c2c2 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgconn/internal/bgreader/bgreader.go @@ -0,0 +1,139 @@ +// Package bgreader provides a io.Reader that can optionally buffer reads in the background. +package bgreader + +import ( + "io" + "sync" + + "github.com/jackc/pgx/v5/internal/iobufpool" +) + +const ( + StatusStopped = iota + StatusRunning + StatusStopping +) + +// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use. +type BGReader struct { + r io.Reader + + cond *sync.Cond + status int32 + readResults []readResult +} + +type readResult struct { + buf *[]byte + err error +} + +// Start starts the backgrounder reader. If the background reader is already running this is a no-op. The background +// reader will stop automatically when the underlying reader returns an error. +func (r *BGReader) Start() { + r.cond.L.Lock() + defer r.cond.L.Unlock() + + switch r.status { + case StatusStopped: + r.status = StatusRunning + go r.bgRead() + case StatusRunning: + // no-op + case StatusStopping: + r.status = StatusRunning + } +} + +// Stop tells the background reader to stop after the in progress Read returns. It is safe to call Stop when the +// background reader is not running. +func (r *BGReader) Stop() { + r.cond.L.Lock() + defer r.cond.L.Unlock() + + switch r.status { + case StatusStopped: + // no-op + case StatusRunning: + r.status = StatusStopping + case StatusStopping: + // no-op + } +} + +// Status returns the current status of the background reader. +func (r *BGReader) Status() int32 { + r.cond.L.Lock() + defer r.cond.L.Unlock() + return r.status +} + +func (r *BGReader) bgRead() { + keepReading := true + for keepReading { + buf := iobufpool.Get(8192) + n, err := r.r.Read(*buf) + *buf = (*buf)[:n] + + r.cond.L.Lock() + r.readResults = append(r.readResults, readResult{buf: buf, err: err}) + if r.status == StatusStopping || err != nil { + r.status = StatusStopped + keepReading = false + } + r.cond.L.Unlock() + r.cond.Broadcast() + } +} + +// Read implements the io.Reader interface. +func (r *BGReader) Read(p []byte) (int, error) { + r.cond.L.Lock() + defer r.cond.L.Unlock() + + if len(r.readResults) > 0 { + return r.readFromReadResults(p) + } + + // There are no unread background read results and the background reader is stopped. + if r.status == StatusStopped { + return r.r.Read(p) + } + + // Wait for results from the background reader + for len(r.readResults) == 0 { + r.cond.Wait() + } + return r.readFromReadResults(p) +} + +// readBackgroundResults reads a result previously read by the background reader. r.cond.L must be held. +func (r *BGReader) readFromReadResults(p []byte) (int, error) { + buf := r.readResults[0].buf + var err error + + n := copy(p, *buf) + if n == len(*buf) { + err = r.readResults[0].err + iobufpool.Put(buf) + if len(r.readResults) == 1 { + r.readResults = nil + } else { + r.readResults = r.readResults[1:] + } + } else { + *buf = (*buf)[n:] + r.readResults[0].buf = buf + } + + return n, err +} + +func New(r io.Reader) *BGReader { + return &BGReader{ + r: r, + cond: &sync.Cond{ + L: &sync.Mutex{}, + }, + } +} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go b/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go index 969675f..3c1af34 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go @@ -63,7 +63,7 @@ func (c *PgConn) gssAuth() error { Data: nextData, } c.frontend.Send(gssResponse) - err = c.frontend.Flush() + err = c.flushWithPotentialWriteReadDeadlock() if err != nil { return err } diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go index 59fa35c..8f602e4 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go @@ -13,11 +13,12 @@ import ( "net" "strconv" "strings" + "sync" "time" "github.com/jackc/pgx/v5/internal/iobufpool" - "github.com/jackc/pgx/v5/internal/nbconn" "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgconn/internal/bgreader" "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" "github.com/jackc/pgx/v5/pgproto3" ) @@ -65,17 +66,24 @@ type NotificationHandler func(*PgConn, *Notification) // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { - conn nbconn.Conn // the non-blocking wrapper for the underlying TCP or unix domain socket connection + conn net.Conn pid uint32 // backend pid secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server txStatus byte frontend *pgproto3.Frontend + bgReader *bgreader.BGReader + slowWriteTimer *time.Timer config *Config status byte // One of connStatus* constants + bufferingReceive bool + bufferingReceiveMux sync.Mutex + bufferingReceiveMsg pgproto3.BackendMessage + bufferingReceiveErr error + peekedMsg pgproto3.BackendMessage // Reusable / preallocated resources @@ -89,7 +97,7 @@ type PgConn struct { } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) -// to provide configuration. See documentation for ParseConfig for details. ctx can be used to cancel a connect attempt. +// to provide configuration. See documentation for [ParseConfig] for details. ctx can be used to cancel a connect attempt. func Connect(ctx context.Context, connString string) (*PgConn, error) { config, err := ParseConfig(connString) if err != nil { @@ -100,7 +108,7 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) { } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) -// and ParseConfigOptions to provide additional configuration. See documentation for ParseConfig for details. ctx can be +// and ParseConfigOptions to provide additional configuration. See documentation for [ParseConfig] for details. ctx can be // used to cancel a connect attempt. func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) { config, err := ParseConfigWithOptions(connString, parseConfigOptions) @@ -112,7 +120,7 @@ func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptio } // Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with -// ParseConfig. ctx can be used to cancel a connect attempt. +// [ParseConfig]. ctx can be used to cancel a connect attempt. // // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An // authentication error will terminate the chain of attempts (like libpq: @@ -146,12 +154,15 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er foundBestServer := false var fallbackConfig *FallbackConfig - for _, fc := range fallbackConfigs { + for i, fc := range fallbackConfigs { // ConnectTimeout restricts the whole connection process. if config.ConnectTimeout != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) - defer cancel() + // create new context first time or when previous host was different + if i == 0 || (fallbackConfigs[i].Host != fallbackConfigs[i-1].Host) { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) + defer cancel() + } } else { ctx = octx } @@ -166,7 +177,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege if pgerr.Code == ERRCODE_INVALID_PASSWORD || - pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION || + pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && fc.TLSConfig != nil || pgerr.Code == ERRCODE_INVALID_CATALOG_NAME || pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { break @@ -203,6 +214,8 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) { var configs []*FallbackConfig + var lookupErrors []error + for _, fb := range fallbacks { // skip resolve for unix sockets if isAbsolutePath(fb.Host) { @@ -217,7 +230,8 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba ips, err := lookupFn(ctx, fb.Host) if err != nil { - return nil, err + lookupErrors = append(lookupErrors, err) + continue } for _, ip := range ips { @@ -242,11 +256,18 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba } } + // See https://github.com/jackc/pgx/issues/1464. When Go 1.20 can be used in pgx consider using errors.Join so all + // errors are reported. + if len(configs) == 0 && len(lookupErrors) > 0 { + return nil, lookupErrors[0] + } + return configs, nil } func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig, - ignoreNotPreferredErr bool) (*PgConn, error) { + ignoreNotPreferredErr bool, +) (*PgConn, error) { pgConn := new(PgConn) pgConn.config = config pgConn.cleanupDone = make(chan struct{}) @@ -257,14 +278,13 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err != nil { return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} } - nbNetConn := nbconn.NewNetConn(netConn, false) - pgConn.conn = nbNetConn - pgConn.contextWatcher = newContextWatcher(nbNetConn) + pgConn.conn = netConn + pgConn.contextWatcher = newContextWatcher(netConn) pgConn.contextWatcher.Watch(ctx) if fallbackConfig.TLSConfig != nil { - nbTLSConn, err := startTLS(nbNetConn, fallbackConfig.TLSConfig) + nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig) pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { netConn.Close() @@ -280,7 +300,10 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.parameterStatuses = make(map[string]string) pgConn.status = connStatusConnecting - pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) + pgConn.bgReader = bgreader.New(pgConn.conn) + pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start) + pgConn.slowWriteTimer.Stop() + pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn) startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, @@ -298,9 +321,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } pgConn.frontend.Send(&startupMsg) - if err := pgConn.frontend.Flush(); err != nil { + if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed to write startup message", err: err} + return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)} } for { @@ -383,7 +406,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { ) } -func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, error) { +func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return nil, err @@ -398,17 +421,12 @@ func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, err return nil, errors.New("server refused TLS connection") } - tlsConn, err := nbconn.TLSClient(conn, tlsConfig) - if err != nil { - return nil, err - } - - return tlsConn, nil + return tls.Client(conn, tlsConfig), nil } func (pgConn *PgConn) txPasswordMessage(password string) (err error) { pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password}) - return pgConn.frontend.Flush() + return pgConn.flushWithPotentialWriteReadDeadlock() } func hexMD5(s string) string { @@ -417,6 +435,24 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } +func (pgConn *PgConn) signalMessage() chan struct{} { + if pgConn.bufferingReceive { + panic("BUG: signalMessage when already in progress") + } + + pgConn.bufferingReceive = true + pgConn.bufferingReceiveMux.Lock() + + ch := make(chan struct{}) + go func() { + pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive() + pgConn.bufferingReceiveMux.Unlock() + close(ch) + }() + + return ch +} + // ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the // connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages // are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger @@ -445,7 +481,8 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa err = &pgconnError{ msg: "receive message failed", err: normalizeTimeoutError(ctx, err), - safeToRetry: true} + safeToRetry: true, + } } return msg, err } @@ -456,13 +493,25 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { return pgConn.peekedMsg, nil } - msg, err := pgConn.frontend.Receive() + var msg pgproto3.BackendMessage + var err error + if pgConn.bufferingReceive { + pgConn.bufferingReceiveMux.Lock() + msg = pgConn.bufferingReceiveMsg + err = pgConn.bufferingReceiveErr + pgConn.bufferingReceiveMux.Unlock() + pgConn.bufferingReceive = false + + // If a timeout error happened in the background try the read again. + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + msg, err = pgConn.frontend.Receive() + } + } else { + msg, err = pgConn.frontend.Receive() + } if err != nil { - if errors.Is(err, nbconn.ErrWouldBlock) { - return nil, err - } - // Close on anything other than timeout error - everything else is fatal var netErr net.Error isNetErr := errors.As(err, &netErr) @@ -510,7 +559,8 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { return msg, nil } -// Conn returns the underlying net.Conn. This rarely necessary. +// Conn returns the underlying net.Conn. This rarely necessary. If the connection will be directly used for reading or +// writing then SyncConn should usually be called before Conn. func (pgConn *PgConn) Conn() net.Conn { return pgConn.conn } @@ -573,7 +623,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error { // // See https://github.com/jackc/pgx/issues/637 pgConn.frontend.Send(&pgproto3.Terminate{}) - pgConn.frontend.Flush() + pgConn.flushWithPotentialWriteReadDeadlock() return pgConn.conn.Close() } @@ -600,7 +650,7 @@ func (pgConn *PgConn) asyncClose() { pgConn.conn.SetDeadline(deadline) pgConn.frontend.Send(&pgproto3.Terminate{}) - pgConn.frontend.Flush() + pgConn.flushWithPotentialWriteReadDeadlock() }() } @@ -775,7 +825,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) pgConn.frontend.SendSync(&pgproto3.Sync{}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() return nil, err @@ -848,9 +898,28 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. serverAddr := pgConn.conn.RemoteAddr() - cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + var serverNetwork string + var serverAddress string + if serverAddr.Network() == "unix" { + // for unix sockets, RemoteAddr() calls getpeername() which returns the name the + // server passed to bind(). For Postgres, this is always a relative path "./.s.PGSQL.5432" + // so connecting to it will fail. Fall back to the config's value + serverNetwork, serverAddress = NetworkAddress(pgConn.config.Host, pgConn.config.Port) + } else { + serverNetwork, serverAddress = serverAddr.Network(), serverAddr.String() + } + cancelConn, err := pgConn.config.DialFunc(ctx, serverNetwork, serverAddress) if err != nil { - return err + // In case of unix sockets, RemoteAddr() returns only the file part of the path. If the + // first connect failed, try the config. + if serverAddr.Network() != "unix" { + return err + } + serverNetwork, serverAddr := NetworkAddress(pgConn.config.Host, pgConn.config.Port) + cancelConn, err = pgConn.config.DialFunc(ctx, serverNetwork, serverAddr) + if err != nil { + return err + } } defer cancelConn.Close() @@ -868,17 +937,11 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { binary.BigEndian.PutUint32(buf[4:8], 80877102) binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid)) binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) + // Postgres will process the request and close the connection + // so when don't need to read the reply + // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.6.7.10 _, err = cancelConn.Write(buf) - if err != nil { - return err - } - - _, err = cancelConn.Read(buf) - if err != io.EOF { - return err - } - - return nil + return err } // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not @@ -944,7 +1007,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { } pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() pgConn.contextWatcher.Unwatch() @@ -1055,7 +1118,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { pgConn.frontend.SendExecute(&pgproto3.Execute{}) pgConn.frontend.SendSync(&pgproto3.Sync{}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() result.concludeCommand(CommandTag{}, err) @@ -1088,7 +1151,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // Send copy to command pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() pgConn.unlock() @@ -1144,85 +1207,91 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co defer pgConn.contextWatcher.Unwatch() } - // Send copy to command + // Send copy from query pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() return CommandTag{}, err } - err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline) - if err != nil { - pgConn.asyncClose() - return CommandTag{}, err - } - nonblocking := true - defer func() { - if nonblocking { - pgConn.conn.SetReadDeadline(time.Time{}) + // Send copy data + abortCopyChan := make(chan struct{}) + copyErrChan := make(chan error, 1) + signalMessageChan := pgConn.signalMessage() + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + buf := iobufpool.Get(65536) + defer iobufpool.Put(buf) + (*buf)[0] = 'd' + + for { + n, readErr := r.Read((*buf)[5:cap(*buf)]) + if n > 0 { + *buf = (*buf)[0 : n+5] + pgio.SetInt32((*buf)[1:], int32(n+4)) + + writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf) + if writeErr != nil { + // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. Not + // setting pgConn.status or closing pgConn.cleanupDone for the same reason. + pgConn.conn.Close() + + copyErrChan <- writeErr + return + } + } + if readErr != nil { + copyErrChan <- readErr + return + } + + select { + case <-abortCopyChan: + return + default: + } } }() - buf := iobufpool.Get(65536) - defer iobufpool.Put(buf) - buf[0] = 'd' - - var readErr, pgErr error - for pgErr == nil { - // Read chunk from r. - var n int - n, readErr = r.Read(buf[5:cap(buf)]) - - // Send chunk to PostgreSQL. - if n > 0 { - buf = buf[0 : n+5] - pgio.SetInt32(buf[1:], int32(n+4)) - - writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(buf) - if writeErr != nil { - pgConn.asyncClose() - return CommandTag{}, err - } - } - - // Abort loop if there was a read error. - if readErr != nil { - break - } - - // Read messages until error or none available. - for pgErr == nil { - msg, err := pgConn.receiveMessage() - if err != nil { - if errors.Is(err, nbconn.ErrWouldBlock) { - break - } - pgConn.asyncClose() + var pgErr error + var copyErr error + for copyErr == nil && pgErr == nil { + select { + case copyErr = <-copyErrChan: + case <-signalMessageChan: + // If pgConn.receiveMessage encounters an error it will call pgConn.asyncClose. But that is a race condition with + // the goroutine. So instead check pgConn.bufferingReceiveErr which will have been set by the signalMessage. If an + // error is found then forcibly close the connection without sending the Terminate message. + if err := pgConn.bufferingReceiveErr; err != nil { + pgConn.status = connStatusClosed + pgConn.conn.Close() + close(pgConn.cleanupDone) return CommandTag{}, normalizeTimeoutError(ctx, err) } + msg, _ := pgConn.receiveMessage() switch msg := msg.(type) { case *pgproto3.ErrorResponse: pgErr = ErrorResponseToPgError(msg) - break + default: + signalMessageChan = pgConn.signalMessage() } } } + close(abortCopyChan) + // Make sure io goroutine finishes before writing. + wg.Wait() - err = pgConn.conn.SetReadDeadline(time.Time{}) - if err != nil { - pgConn.asyncClose() - return CommandTag{}, err - } - nonblocking = false - - if readErr == io.EOF || pgErr != nil { + if copyErr == io.EOF || pgErr != nil { pgConn.frontend.Send(&pgproto3.CopyDone{}) } else { - pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()}) + pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()}) } - err = pgConn.frontend.Flush() + err = pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() return CommandTag{}, err @@ -1274,7 +1343,6 @@ func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { msg, err := mrr.pgConn.receiveMessage() - if err != nil { mrr.pgConn.contextWatcher.Unwatch() mrr.err = normalizeTimeoutError(mrr.ctx, err) @@ -1417,7 +1485,8 @@ func (rr *ResultReader) NextRow() bool { } // FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until -// the ResultReader is closed. +// the ResultReader is closed. It may return nil (for example, if the query did not return a result set or an error was +// encountered.) func (rr *ResultReader) FieldDescriptions() []FieldDescription { return rr.fieldDescriptions } @@ -1583,6 +1652,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) + pgConn.enterPotentialWriteReadDeadlock() + defer pgConn.exitPotentialWriteReadDeadlock() _, err := pgConn.conn.Write(batch.buf) if err != nil { multiResult.closed = true @@ -1611,29 +1682,99 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) { return strings.Replace(s, "'", "''", -1), nil } -// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by reading and -// buffering until the read would block or an error occurs. This can be used to check if the server has closed the -// connection. If this is done immediately before sending a query it reduces the chances a query will be sent that fails +// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by doing a read +// with a very short deadline. This can be useful because a TCP connection can be broken such that a write will appear +// to succeed even though it will never actually reach the server. Reading immediately before a write will detect this +// condition. If this is done immediately before sending a query it reduces the chances a query will be sent that fails // without the client knowing whether the server received it or not. +// +// Deprecated: CheckConn is deprecated in favor of Ping. CheckConn cannot detect all types of broken connections where +// the write would still appear to succeed. Prefer Ping unless on a high latency connection. func (pgConn *PgConn) CheckConn() error { - err := pgConn.conn.BufferReadUntilBlock() - if err != nil && !errors.Is(err, nbconn.ErrWouldBlock) { - return err + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + _, err := pgConn.ReceiveMessage(ctx) + if err != nil { + if !Timeout(err) { + return err + } } + return nil } +// Ping pings the server. This can be useful because a TCP connection can be broken such that a write will appear to +// succeed even though it will never actually reach the server. Pinging immediately before sending a query reduces the +// chances a query will be sent that fails without the client knowing whether the server received it or not. +func (pgConn *PgConn) Ping(ctx context.Context) error { + return pgConn.Exec(ctx, "-- ping").Close() +} + // makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory. func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag { return CommandTag{s: string(buf)} } +// enterPotentialWriteReadDeadlock must be called before a write that could deadlock if the server is simultaneously +// blocked writing to us. +func (pgConn *PgConn) enterPotentialWriteReadDeadlock() { + // The time to wait is somewhat arbitrary. A Write should only take as long as the syscall and memcpy to the OS + // outbound network buffer unless the buffer is full (which potentially is a block). It needs to be long enough for + // the normal case, but short enough not to kill performance if a block occurs. + // + // In addition, on Windows the default timer resolution is 15.6ms. So setting the timer to less than that is + // ineffective. + if pgConn.slowWriteTimer.Reset(15 * time.Millisecond) { + panic("BUG: slow write timer already active") + } +} + +// exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock. +func (pgConn *PgConn) exitPotentialWriteReadDeadlock() { + // The state of the timer is not relevant upon exiting the potential slow write. It may both + // fire (due to a slow write), or not fire (due to a fast write). + _ = pgConn.slowWriteTimer.Stop() + pgConn.bgReader.Stop() +} + +func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error { + pgConn.enterPotentialWriteReadDeadlock() + defer pgConn.exitPotentialWriteReadDeadlock() + err := pgConn.frontend.Flush() + return err +} + +// SyncConn prepares the underlying net.Conn for direct use. PgConn may internally buffer reads or use goroutines for +// background IO. This means that any direct use of the underlying net.Conn may be corrupted if a read is already +// buffered or a read is in progress. SyncConn drains read buffers and stops background IO. In some cases this may +// require sending a ping to the server. ctx can be used to cancel this operation. This should be called before any +// operation that will use the underlying net.Conn directly. e.g. Before Conn() or Hijack(). +// +// This should not be confused with the PostgreSQL protocol Sync message. +func (pgConn *PgConn) SyncConn(ctx context.Context) error { + for i := 0; i < 10; i++ { + if pgConn.bgReader.Status() == bgreader.StatusStopped && pgConn.frontend.ReadBufferLen() == 0 { + return nil + } + + err := pgConn.Ping(ctx) + if err != nil { + return fmt.Errorf("SyncConn: Ping failed while syncing conn: %w", err) + } + } + + // This should never happen. Only way I can imagine this occuring is if the server is constantly sending data such as + // LISTEN/NOTIFY or log notifications such that we never can get an empty buffer. + return errors.New("SyncConn: conn never synchronized") +} + // HijackedConn is the result of hijacking a connection. // // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // compatibility. type HijackedConn struct { - Conn nbconn.Conn // the non-blocking wrapper of the underlying TCP or unix domain socket connection + Conn net.Conn PID uint32 // backend pid SecretKey uint32 // key to use to send a cancel query message to the server ParameterStatuses map[string]string // parameters that have been reported by the server @@ -1642,9 +1783,9 @@ type HijackedConn struct { Config *Config } -// Hijack extracts the internal connection data. pgConn must be in an idle state. pgConn is unusable after hijacking. -// Hijacking is typically only useful when using pgconn to establish a connection, but taking complete control of the -// raw connection after that (e.g. a load balancer or proxy). +// Hijack extracts the internal connection data. pgConn must be in an idle state. SyncConn should be called immediately +// before Hijack. pgConn is unusable after hijacking. Hijacking is typically only useful when using pgconn to establish +// a connection, but taking complete control of the raw connection after that (e.g. a load balancer or proxy). // // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // compatibility. @@ -1668,6 +1809,8 @@ func (pgConn *PgConn) Hijack() (*HijackedConn, error) { // Construct created a PgConn from an already established connection to a PostgreSQL server. This is the inverse of // PgConn.Hijack. The connection must be in an idle state. // +// hc.Frontend is replaced by a new pgproto3.Frontend built by hc.Config.BuildFrontend. +// // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // compatibility. func Construct(hc *HijackedConn) (*PgConn, error) { @@ -1686,6 +1829,10 @@ func Construct(hc *HijackedConn) (*PgConn, error) { } pgConn.contextWatcher = newContextWatcher(pgConn.conn) + pgConn.bgReader = bgreader.New(pgConn.conn) + pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start) + pgConn.slowWriteTimer.Stop() + pgConn.frontend = hc.Config.BuildFrontend(pgConn.bgReader, pgConn.conn) return pgConn, nil } @@ -1808,7 +1955,7 @@ func (p *Pipeline) Flush() error { return errors.New("pipeline closed") } - err := p.conn.frontend.Flush() + err := p.conn.flushWithPotentialWriteReadDeadlock() if err != nil { err = normalizeTimeoutError(p.ctx, err) @@ -1887,7 +2034,6 @@ func (p *Pipeline) GetResults() (results any, err error) { } } - } func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go b/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go index 09aeb7c..6db77e4 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go @@ -196,7 +196,7 @@ func (b *Backend) Receive() (FrontendMessage, error) { case AuthTypeCleartextPassword, AuthTypeMD5Password: fallthrough default: - // to maintain backwards compatability + // to maintain backwards compatibility msg = &PasswordMessage{} } case 'Q': @@ -233,11 +233,11 @@ func (b *Backend) Receive() (FrontendMessage, error) { // contextual identification of FrontendMessages. For example, in the // PG message flow documentation for PasswordMessage: // -// Byte1('p') +// Byte1('p') // -// Identifies the message as a password response. Note that this is also used for -// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from -// the context. +// Identifies the message as a password response. Note that this is also used for +// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from +// the context. // // Since the Frontend does not know about the state of a backend, it is important // to call SetAuthType() after an authentication request is received by the Frontend. diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/chunkreader.go b/vendor/github.com/jackc/pgx/v5/pgproto3/chunkreader.go index 3c35d0b..fc0fa61 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/chunkreader.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/chunkreader.go @@ -14,7 +14,7 @@ import ( type chunkReader struct { r io.Reader - buf []byte + buf *[]byte rp, wp int // buf read position and write position minBufSize int @@ -45,7 +45,7 @@ func newChunkReader(r io.Reader, minBufSize int) *chunkReader { func (r *chunkReader) Next(n int) (buf []byte, err error) { // Reset the buffer if it is empty if r.rp == r.wp { - if len(r.buf) != r.minBufSize { + if len(*r.buf) != r.minBufSize { iobufpool.Put(r.buf) r.buf = iobufpool.Get(r.minBufSize) } @@ -55,15 +55,15 @@ func (r *chunkReader) Next(n int) (buf []byte, err error) { // n bytes already in buf if (r.wp - r.rp) >= n { - buf = r.buf[r.rp : r.rp+n : r.rp+n] + buf = (*r.buf)[r.rp : r.rp+n : r.rp+n] r.rp += n return buf, err } // buf is smaller than requested number of bytes - if len(r.buf) < n { + if len(*r.buf) < n { bigBuf := iobufpool.Get(n) - r.wp = copy(bigBuf, r.buf[r.rp:r.wp]) + r.wp = copy((*bigBuf), (*r.buf)[r.rp:r.wp]) r.rp = 0 iobufpool.Put(r.buf) r.buf = bigBuf @@ -71,20 +71,20 @@ func (r *chunkReader) Next(n int) (buf []byte, err error) { // buf is large enough, but need to shift filled area to start to make enough contiguous space minReadCount := n - (r.wp - r.rp) - if (len(r.buf) - r.wp) < minReadCount { - r.wp = copy(r.buf, r.buf[r.rp:r.wp]) + if (len(*r.buf) - r.wp) < minReadCount { + r.wp = copy((*r.buf), (*r.buf)[r.rp:r.wp]) r.rp = 0 } // Read at least the required number of bytes from the underlying io.Reader - readBytesCount, err := io.ReadAtLeast(r.r, r.buf[r.wp:], minReadCount) + readBytesCount, err := io.ReadAtLeast(r.r, (*r.buf)[r.wp:], minReadCount) r.wp += readBytesCount // fmt.Println("read", n) if err != nil { return nil, err } - buf = r.buf[r.rp : r.rp+n : r.rp+n] + buf = (*r.buf)[r.rp : r.rp+n : r.rp+n] r.rp += n return buf, nil } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go b/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go index 83dea96..33c3882 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go @@ -361,3 +361,7 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er func (f *Frontend) GetAuthType() uint32 { return f.authType } + +func (f *Frontend) ReadBufferLen() int { + return f.cr.wp - f.cr.rp +} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/trace.go b/vendor/github.com/jackc/pgx/v5/pgproto3/trace.go index c09f68d..6cc7d3e 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/trace.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/trace.go @@ -6,15 +6,18 @@ import ( "io" "strconv" "strings" + "sync" "time" ) // tracer traces the messages send to and from a Backend or Frontend. The format it produces roughly mimics the // format produced by the libpq C function PQtrace. type tracer struct { + TracerOptions + + mux sync.Mutex w io.Writer buf *bytes.Buffer - TracerOptions } // TracerOptions controls tracing behavior. It is roughly equivalent to the libpq function PQsetTraceFlags. @@ -119,278 +122,255 @@ func (t *tracer) traceMessage(sender byte, encodedLen int32, msg Message) { case *Terminate: t.traceTerminate(sender, encodedLen, msg) default: - t.beginTrace(sender, encodedLen, "Unknown") - t.finishTrace() + t.writeTrace(sender, encodedLen, "Unknown", nil) } } func (t *tracer) traceAuthenticationCleartextPassword(sender byte, encodedLen int32, msg *AuthenticationCleartextPassword) { - t.beginTrace(sender, encodedLen, "AuthenticationCleartextPassword") - t.finishTrace() + t.writeTrace(sender, encodedLen, "AuthenticationCleartextPassword", nil) } func (t *tracer) traceAuthenticationGSS(sender byte, encodedLen int32, msg *AuthenticationGSS) { - t.beginTrace(sender, encodedLen, "AuthenticationGSS") - t.finishTrace() + t.writeTrace(sender, encodedLen, "AuthenticationGSS", nil) } func (t *tracer) traceAuthenticationGSSContinue(sender byte, encodedLen int32, msg *AuthenticationGSSContinue) { - t.beginTrace(sender, encodedLen, "AuthenticationGSSContinue") - t.finishTrace() + t.writeTrace(sender, encodedLen, "AuthenticationGSSContinue", nil) } func (t *tracer) traceAuthenticationMD5Password(sender byte, encodedLen int32, msg *AuthenticationMD5Password) { - t.beginTrace(sender, encodedLen, "AuthenticationMD5Password") - t.finishTrace() + t.writeTrace(sender, encodedLen, "AuthenticationMD5Password", nil) } func (t *tracer) traceAuthenticationOk(sender byte, encodedLen int32, msg *AuthenticationOk) { - t.beginTrace(sender, encodedLen, "AuthenticationOk") - t.finishTrace() + t.writeTrace(sender, encodedLen, "AuthenticationOk", nil) } func (t *tracer) traceAuthenticationSASL(sender byte, encodedLen int32, msg *AuthenticationSASL) { - t.beginTrace(sender, encodedLen, "AuthenticationSASL") - t.finishTrace() + t.writeTrace(sender, encodedLen, "AuthenticationSASL", nil) } func (t *tracer) traceAuthenticationSASLContinue(sender byte, encodedLen int32, msg *AuthenticationSASLContinue) { - t.beginTrace(sender, encodedLen, "AuthenticationSASLContinue") - t.finishTrace() + t.writeTrace(sender, encodedLen, "AuthenticationSASLContinue", nil) } func (t *tracer) traceAuthenticationSASLFinal(sender byte, encodedLen int32, msg *AuthenticationSASLFinal) { - t.beginTrace(sender, encodedLen, "AuthenticationSASLFinal") - t.finishTrace() + t.writeTrace(sender, encodedLen, "AuthenticationSASLFinal", nil) } func (t *tracer) traceBackendKeyData(sender byte, encodedLen int32, msg *BackendKeyData) { - t.beginTrace(sender, encodedLen, "BackendKeyData") - if t.RegressMode { - t.buf.WriteString("\t NNNN NNNN") - } else { - fmt.Fprintf(t.buf, "\t %d %d", msg.ProcessID, msg.SecretKey) - } - t.finishTrace() + t.writeTrace(sender, encodedLen, "BackendKeyData", func() { + if t.RegressMode { + t.buf.WriteString("\t NNNN NNNN") + } else { + fmt.Fprintf(t.buf, "\t %d %d", msg.ProcessID, msg.SecretKey) + } + }) } func (t *tracer) traceBind(sender byte, encodedLen int32, msg *Bind) { - t.beginTrace(sender, encodedLen, "Bind") - fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.DestinationPortal)), traceDoubleQuotedString([]byte(msg.PreparedStatement)), len(msg.ParameterFormatCodes)) - for _, fc := range msg.ParameterFormatCodes { - fmt.Fprintf(t.buf, " %d", fc) - } - fmt.Fprintf(t.buf, " %d", len(msg.Parameters)) - for _, p := range msg.Parameters { - fmt.Fprintf(t.buf, " %s", traceSingleQuotedString(p)) - } - fmt.Fprintf(t.buf, " %d", len(msg.ResultFormatCodes)) - for _, fc := range msg.ResultFormatCodes { - fmt.Fprintf(t.buf, " %d", fc) - } - t.finishTrace() + t.writeTrace(sender, encodedLen, "Bind", func() { + fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.DestinationPortal)), traceDoubleQuotedString([]byte(msg.PreparedStatement)), len(msg.ParameterFormatCodes)) + for _, fc := range msg.ParameterFormatCodes { + fmt.Fprintf(t.buf, " %d", fc) + } + fmt.Fprintf(t.buf, " %d", len(msg.Parameters)) + for _, p := range msg.Parameters { + fmt.Fprintf(t.buf, " %s", traceSingleQuotedString(p)) + } + fmt.Fprintf(t.buf, " %d", len(msg.ResultFormatCodes)) + for _, fc := range msg.ResultFormatCodes { + fmt.Fprintf(t.buf, " %d", fc) + } + }) } func (t *tracer) traceBindComplete(sender byte, encodedLen int32, msg *BindComplete) { - t.beginTrace(sender, encodedLen, "BindComplete") - t.finishTrace() + t.writeTrace(sender, encodedLen, "BindComplete", nil) } func (t *tracer) traceCancelRequest(sender byte, encodedLen int32, msg *CancelRequest) { - t.beginTrace(sender, encodedLen, "CancelRequest") - t.finishTrace() + t.writeTrace(sender, encodedLen, "CancelRequest", nil) } func (t *tracer) traceClose(sender byte, encodedLen int32, msg *Close) { - t.beginTrace(sender, encodedLen, "Close") - t.finishTrace() + t.writeTrace(sender, encodedLen, "Close", nil) } func (t *tracer) traceCloseComplete(sender byte, encodedLen int32, msg *CloseComplete) { - t.beginTrace(sender, encodedLen, "CloseComplete") - t.finishTrace() + t.writeTrace(sender, encodedLen, "CloseComplete", nil) } func (t *tracer) traceCommandComplete(sender byte, encodedLen int32, msg *CommandComplete) { - t.beginTrace(sender, encodedLen, "CommandComplete") - fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString(msg.CommandTag)) - t.finishTrace() + t.writeTrace(sender, encodedLen, "CommandComplete", func() { + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString(msg.CommandTag)) + }) } func (t *tracer) traceCopyBothResponse(sender byte, encodedLen int32, msg *CopyBothResponse) { - t.beginTrace(sender, encodedLen, "CopyBothResponse") - t.finishTrace() + t.writeTrace(sender, encodedLen, "CopyBothResponse", nil) } func (t *tracer) traceCopyData(sender byte, encodedLen int32, msg *CopyData) { - t.beginTrace(sender, encodedLen, "CopyData") - t.finishTrace() + t.writeTrace(sender, encodedLen, "CopyData", nil) } func (t *tracer) traceCopyDone(sender byte, encodedLen int32, msg *CopyDone) { - t.beginTrace(sender, encodedLen, "CopyDone") - t.finishTrace() + t.writeTrace(sender, encodedLen, "CopyDone", nil) } func (t *tracer) traceCopyFail(sender byte, encodedLen int32, msg *CopyFail) { - t.beginTrace(sender, encodedLen, "CopyFail") - fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.Message))) - t.finishTrace() + t.writeTrace(sender, encodedLen, "CopyFail", func() { + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.Message))) + }) } func (t *tracer) traceCopyInResponse(sender byte, encodedLen int32, msg *CopyInResponse) { - t.beginTrace(sender, encodedLen, "CopyInResponse") - t.finishTrace() + t.writeTrace(sender, encodedLen, "CopyInResponse", nil) } func (t *tracer) traceCopyOutResponse(sender byte, encodedLen int32, msg *CopyOutResponse) { - t.beginTrace(sender, encodedLen, "CopyOutResponse") - t.finishTrace() + t.writeTrace(sender, encodedLen, "CopyOutResponse", nil) } func (t *tracer) traceDataRow(sender byte, encodedLen int32, msg *DataRow) { - t.beginTrace(sender, encodedLen, "DataRow") - fmt.Fprintf(t.buf, "\t %d", len(msg.Values)) - for _, v := range msg.Values { - if v == nil { - t.buf.WriteString(" -1") - } else { - fmt.Fprintf(t.buf, " %d %s", len(v), traceSingleQuotedString(v)) + t.writeTrace(sender, encodedLen, "DataRow", func() { + fmt.Fprintf(t.buf, "\t %d", len(msg.Values)) + for _, v := range msg.Values { + if v == nil { + t.buf.WriteString(" -1") + } else { + fmt.Fprintf(t.buf, " %d %s", len(v), traceSingleQuotedString(v)) + } } - } - t.finishTrace() + }) } func (t *tracer) traceDescribe(sender byte, encodedLen int32, msg *Describe) { - t.beginTrace(sender, encodedLen, "Describe") - fmt.Fprintf(t.buf, "\t %c %s", msg.ObjectType, traceDoubleQuotedString([]byte(msg.Name))) - t.finishTrace() + t.writeTrace(sender, encodedLen, "Describe", func() { + fmt.Fprintf(t.buf, "\t %c %s", msg.ObjectType, traceDoubleQuotedString([]byte(msg.Name))) + }) } func (t *tracer) traceEmptyQueryResponse(sender byte, encodedLen int32, msg *EmptyQueryResponse) { - t.beginTrace(sender, encodedLen, "EmptyQueryResponse") - t.finishTrace() + t.writeTrace(sender, encodedLen, "EmptyQueryResponse", nil) } func (t *tracer) traceErrorResponse(sender byte, encodedLen int32, msg *ErrorResponse) { - t.beginTrace(sender, encodedLen, "ErrorResponse") - t.finishTrace() + t.writeTrace(sender, encodedLen, "ErrorResponse", nil) } func (t *tracer) TraceQueryute(sender byte, encodedLen int32, msg *Execute) { - t.beginTrace(sender, encodedLen, "Execute") - fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) - t.finishTrace() + t.writeTrace(sender, encodedLen, "Execute", func() { + fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) + }) } func (t *tracer) traceFlush(sender byte, encodedLen int32, msg *Flush) { - t.beginTrace(sender, encodedLen, "Flush") - t.finishTrace() + t.writeTrace(sender, encodedLen, "Flush", nil) } func (t *tracer) traceFunctionCall(sender byte, encodedLen int32, msg *FunctionCall) { - t.beginTrace(sender, encodedLen, "FunctionCall") - t.finishTrace() + t.writeTrace(sender, encodedLen, "FunctionCall", nil) } func (t *tracer) traceFunctionCallResponse(sender byte, encodedLen int32, msg *FunctionCallResponse) { - t.beginTrace(sender, encodedLen, "FunctionCallResponse") - t.finishTrace() + t.writeTrace(sender, encodedLen, "FunctionCallResponse", nil) } func (t *tracer) traceGSSEncRequest(sender byte, encodedLen int32, msg *GSSEncRequest) { - t.beginTrace(sender, encodedLen, "GSSEncRequest") - t.finishTrace() + t.writeTrace(sender, encodedLen, "GSSEncRequest", nil) } func (t *tracer) traceNoData(sender byte, encodedLen int32, msg *NoData) { - t.beginTrace(sender, encodedLen, "NoData") - t.finishTrace() + t.writeTrace(sender, encodedLen, "NoData", nil) } func (t *tracer) traceNoticeResponse(sender byte, encodedLen int32, msg *NoticeResponse) { - t.beginTrace(sender, encodedLen, "NoticeResponse") - t.finishTrace() + t.writeTrace(sender, encodedLen, "NoticeResponse", nil) } func (t *tracer) traceNotificationResponse(sender byte, encodedLen int32, msg *NotificationResponse) { - t.beginTrace(sender, encodedLen, "NotificationResponse") - fmt.Fprintf(t.buf, "\t %d %s %s", msg.PID, traceDoubleQuotedString([]byte(msg.Channel)), traceDoubleQuotedString([]byte(msg.Payload))) - t.finishTrace() + t.writeTrace(sender, encodedLen, "NotificationResponse", func() { + fmt.Fprintf(t.buf, "\t %d %s %s", msg.PID, traceDoubleQuotedString([]byte(msg.Channel)), traceDoubleQuotedString([]byte(msg.Payload))) + }) } func (t *tracer) traceParameterDescription(sender byte, encodedLen int32, msg *ParameterDescription) { - t.beginTrace(sender, encodedLen, "ParameterDescription") - t.finishTrace() + t.writeTrace(sender, encodedLen, "ParameterDescription", nil) } func (t *tracer) traceParameterStatus(sender byte, encodedLen int32, msg *ParameterStatus) { - t.beginTrace(sender, encodedLen, "ParameterStatus") - fmt.Fprintf(t.buf, "\t %s %s", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Value))) - t.finishTrace() + t.writeTrace(sender, encodedLen, "ParameterStatus", func() { + fmt.Fprintf(t.buf, "\t %s %s", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Value))) + }) } func (t *tracer) traceParse(sender byte, encodedLen int32, msg *Parse) { - t.beginTrace(sender, encodedLen, "Parse") - fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Query)), len(msg.ParameterOIDs)) - for _, oid := range msg.ParameterOIDs { - fmt.Fprintf(t.buf, " %d", oid) - } - t.finishTrace() + t.writeTrace(sender, encodedLen, "Parse", func() { + fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Query)), len(msg.ParameterOIDs)) + for _, oid := range msg.ParameterOIDs { + fmt.Fprintf(t.buf, " %d", oid) + } + }) } func (t *tracer) traceParseComplete(sender byte, encodedLen int32, msg *ParseComplete) { - t.beginTrace(sender, encodedLen, "ParseComplete") - t.finishTrace() + t.writeTrace(sender, encodedLen, "ParseComplete", nil) } func (t *tracer) tracePortalSuspended(sender byte, encodedLen int32, msg *PortalSuspended) { - t.beginTrace(sender, encodedLen, "PortalSuspended") - t.finishTrace() + t.writeTrace(sender, encodedLen, "PortalSuspended", nil) } func (t *tracer) traceQuery(sender byte, encodedLen int32, msg *Query) { - t.beginTrace(sender, encodedLen, "Query") - fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.String))) - t.finishTrace() + t.writeTrace(sender, encodedLen, "Query", func() { + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.String))) + }) } func (t *tracer) traceReadyForQuery(sender byte, encodedLen int32, msg *ReadyForQuery) { - t.beginTrace(sender, encodedLen, "ReadyForQuery") - fmt.Fprintf(t.buf, "\t %c", msg.TxStatus) - t.finishTrace() + t.writeTrace(sender, encodedLen, "ReadyForQuery", func() { + fmt.Fprintf(t.buf, "\t %c", msg.TxStatus) + }) } func (t *tracer) traceRowDescription(sender byte, encodedLen int32, msg *RowDescription) { - t.beginTrace(sender, encodedLen, "RowDescription") - fmt.Fprintf(t.buf, "\t %d", len(msg.Fields)) - for _, fd := range msg.Fields { - fmt.Fprintf(t.buf, ` %s %d %d %d %d %d %d`, traceDoubleQuotedString(fd.Name), fd.TableOID, fd.TableAttributeNumber, fd.DataTypeOID, fd.DataTypeSize, fd.TypeModifier, fd.Format) - } - t.finishTrace() + t.writeTrace(sender, encodedLen, "RowDescription", func() { + fmt.Fprintf(t.buf, "\t %d", len(msg.Fields)) + for _, fd := range msg.Fields { + fmt.Fprintf(t.buf, ` %s %d %d %d %d %d %d`, traceDoubleQuotedString(fd.Name), fd.TableOID, fd.TableAttributeNumber, fd.DataTypeOID, fd.DataTypeSize, fd.TypeModifier, fd.Format) + } + }) } func (t *tracer) traceSSLRequest(sender byte, encodedLen int32, msg *SSLRequest) { - t.beginTrace(sender, encodedLen, "SSLRequest") - t.finishTrace() + t.writeTrace(sender, encodedLen, "SSLRequest", nil) } func (t *tracer) traceStartupMessage(sender byte, encodedLen int32, msg *StartupMessage) { - t.beginTrace(sender, encodedLen, "StartupMessage") - t.finishTrace() + t.writeTrace(sender, encodedLen, "StartupMessage", nil) } func (t *tracer) traceSync(sender byte, encodedLen int32, msg *Sync) { - t.beginTrace(sender, encodedLen, "Sync") - t.finishTrace() + t.writeTrace(sender, encodedLen, "Sync", nil) } func (t *tracer) traceTerminate(sender byte, encodedLen int32, msg *Terminate) { - t.beginTrace(sender, encodedLen, "Terminate") - t.finishTrace() + t.writeTrace(sender, encodedLen, "Terminate", nil) } -func (t *tracer) beginTrace(sender byte, encodedLen int32, msgType string) { +func (t *tracer) writeTrace(sender byte, encodedLen int32, msgType string, writeDetails func()) { + t.mux.Lock() + defer t.mux.Unlock() + defer func() { + if t.buf.Cap() > 1024 { + t.buf = &bytes.Buffer{} + } else { + t.buf.Reset() + } + }() + if !t.SuppressTimestamps { now := time.Now() t.buf.WriteString(now.Format("2006-01-02 15:04:05.000000")) @@ -402,17 +382,13 @@ func (t *tracer) beginTrace(sender byte, encodedLen int32, msgType string) { t.buf.WriteString(msgType) t.buf.WriteByte('\t') t.buf.WriteString(strconv.FormatInt(int64(encodedLen), 10)) -} -func (t *tracer) finishTrace() { + if writeDetails != nil { + writeDetails() + } + t.buf.WriteByte('\n') t.buf.WriteTo(t.w) - - if t.buf.Cap() > 1024 { - t.buf = &bytes.Buffer{} - } else { - t.buf.Reset() - } } // traceDoubleQuotedString returns t.buf as a double-quoted string without any escaping. It is roughly equivalent to diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/array.go b/vendor/github.com/jackc/pgx/v5/pgtype/array.go index 0fa4c12..7376195 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/array.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/array.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "fmt" "io" - "reflect" "strconv" "strings" "unicode" @@ -363,38 +362,18 @@ func quoteArrayElement(src string) string { } func isSpace(ch byte) bool { - // see https://github.com/postgres/postgres/blob/REL_12_STABLE/src/backend/parser/scansup.c#L224 - return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\f' + // see array_isspace: + // https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/arrayfuncs.c + return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\v' || ch == '\f' } func quoteArrayElementIfNeeded(src string) string { - if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) { + if src == "" || (len(src) == 4 && strings.EqualFold(src, "null")) || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) { return quoteArrayElement(src) } return src } -func findDimensionsFromValue(value reflect.Value, dimensions []ArrayDimension, elementsLength int) ([]ArrayDimension, int, bool) { - switch value.Kind() { - case reflect.Array: - fallthrough - case reflect.Slice: - length := value.Len() - if 0 == elementsLength { - elementsLength = length - } else { - elementsLength *= length - } - dimensions = append(dimensions, ArrayDimension{Length: int32(length), LowerBound: 1}) - for i := 0; i < length; i++ { - if d, l, ok := findDimensionsFromValue(value.Index(i), dimensions, elementsLength); ok { - return d, l, true - } - } - } - return dimensions, elementsLength, true -} - // Array represents a PostgreSQL array for T. It implements the ArrayGetter and ArraySetter interfaces. It preserves // PostgreSQL dimensions and custom lower bounds. Use FlatArray if these are not needed. type Array[T any] struct { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go b/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go index dae1203..c1863b3 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go @@ -47,7 +47,16 @@ func (c *ArrayCodec) FormatSupported(format int16) bool { } func (c *ArrayCodec) PreferredFormat() int16 { - return c.ElementType.Codec.PreferredFormat() + // The binary format should always be preferred for arrays if it is supported. Usually, this will happen automatically + // because most types that support binary prefer it. However, text, json, and jsonb support binary but prefer the text + // format. This is because it is simpler for jsonb and PostgreSQL can be significantly faster using the text format + // for text-like data types than binary. However, arrays appear to always be faster in binary. + // + // https://www.postgresql.org/message-id/CAMovtNoHFod2jMAKQjjxv209PCTJx5Kc66anwWvX0mEiaXwgmA%40mail.gmail.com + if c.ElementType.Codec.FormatSupported(BinaryFormatCode) { + return BinaryFormatCode + } + return TextFormatCode } func (c *ArrayCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { @@ -60,7 +69,9 @@ func (c *ArrayCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Enc elementEncodePlan := m.PlanEncode(c.ElementType.OID, format, elementType) if elementEncodePlan == nil { - return nil + if reflect.TypeOf(elementType) != nil { + return nil + } } switch format { @@ -301,7 +312,7 @@ func (c *ArrayCodec) decodeText(m *Map, arrayOID uint32, src []byte, array Array for i, s := range uta.Elements { elem := array.ScanIndex(i) var elemSrc []byte - if s != "NULL" { + if s != "NULL" || uta.Quoted[i] { elemSrc = []byte(s) } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/bool.go b/vendor/github.com/jackc/pgx/v5/pgtype/bool.go index e7be27e..71caffa 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/bool.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/bool.go @@ -1,10 +1,12 @@ package pgtype import ( + "bytes" "database/sql/driver" "encoding/json" "fmt" "strconv" + "strings" ) type BoolScanner interface { @@ -264,8 +266,8 @@ func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error { return fmt.Errorf("cannot scan NULL into %T", dst) } - if len(src) != 1 { - return fmt.Errorf("invalid length for bool: %v", len(src)) + if len(src) == 0 { + return fmt.Errorf("cannot scan empty string into %T", dst) } p, ok := (dst).(*bool) @@ -273,7 +275,12 @@ func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error { return ErrScanTargetTypeChanged } - *p = src[0] == 't' + v, err := planTextToBool(src) + if err != nil { + return err + } + + *p = v return nil } @@ -309,9 +316,28 @@ func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst any) error { return s.ScanBool(Bool{}) } - if len(src) != 1 { - return fmt.Errorf("invalid length for bool: %v", len(src)) + if len(src) == 0 { + return fmt.Errorf("cannot scan empty string into %T", dst) } - return s.ScanBool(Bool{Bool: src[0] == 't', Valid: true}) + v, err := planTextToBool(src) + if err != nil { + return err + } + + return s.ScanBool(Bool{Bool: v, Valid: true}) +} + +// https://www.postgresql.org/docs/11/datatype-boolean.html +func planTextToBool(src []byte) (bool, error) { + s := string(bytes.ToLower(bytes.TrimSpace(src))) + + switch { + case strings.HasPrefix("true", s), strings.HasPrefix("yes", s), s == "on", s == "1": + return true, nil + case strings.HasPrefix("false", s), strings.HasPrefix("no", s), strings.HasPrefix("off", s), s == "0": + return false, nil + default: + return false, fmt.Errorf("unknown boolean string representation %q", src) + } } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go b/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go index e34dd57..8bf367c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go @@ -910,3 +910,43 @@ func (a *anyMultiDimSliceArray) ScanIndexType() any { } return reflect.New(lowestSliceType.Elem()).Interface() } + +type anyArrayArrayReflect struct { + array reflect.Value +} + +func (a anyArrayArrayReflect) Dimensions() []ArrayDimension { + return []ArrayDimension{{Length: int32(a.array.Len()), LowerBound: 1}} +} + +func (a anyArrayArrayReflect) Index(i int) any { + return a.array.Index(i).Interface() +} + +func (a anyArrayArrayReflect) IndexType() any { + return reflect.New(a.array.Type().Elem()).Elem().Interface() +} + +func (a *anyArrayArrayReflect) SetDimensions(dimensions []ArrayDimension) error { + if dimensions == nil { + return fmt.Errorf("anyArrayArrayReflect: cannot scan NULL into %v", a.array.Type().String()) + } + + if len(dimensions) != 1 { + return fmt.Errorf("anyArrayArrayReflect: cannot scan multi-dimensional array into %v", a.array.Type().String()) + } + + if int(dimensions[0].Length) != a.array.Len() { + return fmt.Errorf("anyArrayArrayReflect: cannot scan array with length %v into %v", dimensions[0].Length, a.array.Type().String()) + } + + return nil +} + +func (a *anyArrayArrayReflect) ScanIndex(i int) any { + return a.array.Index(i).Addr().Interface() +} + +func (a *anyArrayArrayReflect) ScanIndexType() any { + return reflect.New(a.array.Type().Elem()).Interface() +} diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/bytea.go b/vendor/github.com/jackc/pgx/v5/pgtype/bytea.go index 2e06767..a247705 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/bytea.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/bytea.go @@ -238,7 +238,7 @@ func decodeHexBytea(src []byte) ([]byte, error) { } func (c ByteaCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { - return codecDecodeToTextFormat(c, m, oid, format, src) + return c.DecodeValue(m, oid, format, src) } func (c ByteaCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/convert.go b/vendor/github.com/jackc/pgx/v5/pgtype/convert.go index 8a2afbe..8a9cee9 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/convert.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/convert.go @@ -1,377 +1,9 @@ package pgtype import ( - "database/sql" - "fmt" - "math" "reflect" - "time" ) -const ( - maxUint = ^uint(0) - maxInt = int(maxUint >> 1) - minInt = -maxInt - 1 -) - -// underlyingNumberType gets the underlying type that can be converted to Int2, Int4, Int8, Float4, or Float8 -func underlyingNumberType(val any) (any, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - case reflect.Int: - convVal := int(refVal.Int()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Int8: - convVal := int8(refVal.Int()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Int16: - convVal := int16(refVal.Int()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Int32: - convVal := int32(refVal.Int()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Int64: - convVal := int64(refVal.Int()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Uint: - convVal := uint(refVal.Uint()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Uint8: - convVal := uint8(refVal.Uint()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Uint16: - convVal := uint16(refVal.Uint()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Uint32: - convVal := uint32(refVal.Uint()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Uint64: - convVal := uint64(refVal.Uint()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Float32: - convVal := float32(refVal.Float()) - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.Float64: - convVal := refVal.Float() - return convVal, reflect.TypeOf(convVal) != refVal.Type() - case reflect.String: - convVal := refVal.String() - return convVal, reflect.TypeOf(convVal) != refVal.Type() - } - - return nil, false -} - -// underlyingBoolType gets the underlying type that can be converted to Bool -func underlyingBoolType(val any) (any, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - case reflect.Bool: - convVal := refVal.Bool() - return convVal, reflect.TypeOf(convVal) != refVal.Type() - } - - return nil, false -} - -// underlyingBytesType gets the underlying type that can be converted to []byte -func underlyingBytesType(val any) (any, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - case reflect.Slice: - if refVal.Type().Elem().Kind() == reflect.Uint8 { - convVal := refVal.Bytes() - return convVal, reflect.TypeOf(convVal) != refVal.Type() - } - } - - return nil, false -} - -// underlyingStringType gets the underlying type that can be converted to String -func underlyingStringType(val any) (any, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - case reflect.String: - convVal := refVal.String() - return convVal, reflect.TypeOf(convVal) != refVal.Type() - } - - return nil, false -} - -// underlyingPtrType dereferences a pointer -func underlyingPtrType(val any) (any, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - } - - return nil, false -} - -// underlyingTimeType gets the underlying type that can be converted to time.Time -func underlyingTimeType(val any) (any, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - } - - timeType := reflect.TypeOf(time.Time{}) - if refVal.Type().ConvertibleTo(timeType) { - return refVal.Convert(timeType).Interface(), true - } - - return nil, false -} - -// underlyingUUIDType gets the underlying type that can be converted to [16]byte -func underlyingUUIDType(val any) (any, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return time.Time{}, false - } - convVal := refVal.Elem().Interface() - return convVal, true - } - - uuidType := reflect.TypeOf([16]byte{}) - if refVal.Type().ConvertibleTo(uuidType) { - return refVal.Convert(uuidType).Interface(), true - } - - return nil, false -} - -// underlyingSliceType gets the underlying slice type -func underlyingSliceType(val any) (any, bool) { - refVal := reflect.ValueOf(val) - - switch refVal.Kind() { - case reflect.Ptr: - if refVal.IsNil() { - return nil, false - } - convVal := refVal.Elem().Interface() - return convVal, true - case reflect.Slice: - baseSliceType := reflect.SliceOf(refVal.Type().Elem()) - if refVal.Type().ConvertibleTo(baseSliceType) { - convVal := refVal.Convert(baseSliceType) - return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() - } - } - - return nil, false -} - -func int64AssignTo(srcVal int64, srcValid bool, dst any) error { - if srcValid { - switch v := dst.(type) { - case *int: - if srcVal < int64(minInt) { - return fmt.Errorf("%d is less than minimum value for int", srcVal) - } else if srcVal > int64(maxInt) { - return fmt.Errorf("%d is greater than maximum value for int", srcVal) - } - *v = int(srcVal) - case *int8: - if srcVal < math.MinInt8 { - return fmt.Errorf("%d is less than minimum value for int8", srcVal) - } else if srcVal > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for int8", srcVal) - } - *v = int8(srcVal) - case *int16: - if srcVal < math.MinInt16 { - return fmt.Errorf("%d is less than minimum value for int16", srcVal) - } else if srcVal > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for int16", srcVal) - } - *v = int16(srcVal) - case *int32: - if srcVal < math.MinInt32 { - return fmt.Errorf("%d is less than minimum value for int32", srcVal) - } else if srcVal > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for int32", srcVal) - } - *v = int32(srcVal) - case *int64: - if srcVal < math.MinInt64 { - return fmt.Errorf("%d is less than minimum value for int64", srcVal) - } else if srcVal > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for int64", srcVal) - } - *v = int64(srcVal) - case *uint: - if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint", srcVal) - } else if uint64(srcVal) > uint64(maxUint) { - return fmt.Errorf("%d is greater than maximum value for uint", srcVal) - } - *v = uint(srcVal) - case *uint8: - if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint8", srcVal) - } else if srcVal > math.MaxUint8 { - return fmt.Errorf("%d is greater than maximum value for uint8", srcVal) - } - *v = uint8(srcVal) - case *uint16: - if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint32", srcVal) - } else if srcVal > math.MaxUint16 { - return fmt.Errorf("%d is greater than maximum value for uint16", srcVal) - } - *v = uint16(srcVal) - case *uint32: - if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint32", srcVal) - } else if srcVal > math.MaxUint32 { - return fmt.Errorf("%d is greater than maximum value for uint32", srcVal) - } - *v = uint32(srcVal) - case *uint64: - if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint64", srcVal) - } - *v = uint64(srcVal) - case sql.Scanner: - return v.Scan(srcVal) - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return int64AssignTo(srcVal, srcValid, el.Interface()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if el.OverflowInt(int64(srcVal)) { - return fmt.Errorf("cannot put %d into %T", srcVal, dst) - } - el.SetInt(int64(srcVal)) - return nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if srcVal < 0 { - return fmt.Errorf("%d is less than zero for %T", srcVal, dst) - } - if el.OverflowUint(uint64(srcVal)) { - return fmt.Errorf("cannot put %d into %T", srcVal, dst) - } - el.SetUint(uint64(srcVal)) - return nil - } - } - return fmt.Errorf("cannot assign %v into %T", srcVal, dst) - } - return nil - } - - // if dst is a pointer to pointer and srcStatus is not Valid, nil it out - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - if el.Kind() == reflect.Ptr { - el.Set(reflect.Zero(el.Type())) - return nil - } - } - - return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcValid, dst) -} - -func float64AssignTo(srcVal float64, srcValid bool, dst any) error { - if srcValid { - switch v := dst.(type) { - case *float32: - *v = float32(srcVal) - case *float64: - *v = srcVal - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a type alias of a float32 or 64, set dst val - case reflect.Float32, reflect.Float64: - el.SetFloat(srcVal) - return nil - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return float64AssignTo(srcVal, srcValid, el.Interface()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - i64 := int64(srcVal) - if float64(i64) == srcVal { - return int64AssignTo(i64, srcValid, dst) - } - } - } - return fmt.Errorf("cannot assign %v into %T", srcVal, dst) - } - return nil - } - - // if dst is a pointer to pointer and srcStatus is not Valid, nil it out - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - if el.Kind() == reflect.Ptr { - el.Set(reflect.Zero(el.Type())) - return nil - } - } - - return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcValid, dst) -} - func NullAssignTo(dst any) error { dstPtr := reflect.ValueOf(dst) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/doc.go b/vendor/github.com/jackc/pgx/v5/pgtype/doc.go index 834aa0b..6612c89 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/doc.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/doc.go @@ -57,27 +57,7 @@ JSON Support pgtype automatically marshals and unmarshals data from json and jsonb PostgreSQL types. -Array Support - -ArrayCodec implements support for arrays. If pgtype supports type T then it can easily support []T by registering an -ArrayCodec for the appropriate PostgreSQL OID. In addition, Array[T] type can support multi-dimensional arrays. - -Composite Support - -CompositeCodec implements support for PostgreSQL composite types. Go structs can be scanned into if the public fields of -the struct are in the exact order and type of the PostgreSQL type or by implementing CompositeIndexScanner and -CompositeIndexGetter. - -Enum Support - -PostgreSQL enums can usually be treated as text. However, EnumCodec implements support for interning strings which can -reduce memory usage. - -Array, Composite, and Enum Type Registration - -Array, composite, and enum types can be easily registered from a pgx.Conn with the LoadType method. - -Extending Existing Type Support +Extending Existing PostgreSQL Type Support Generally, all Codecs will support interfaces that can be implemented to enable scanning and encoding. For example, PointCodec can use any Go type that implements the PointScanner and PointValuer interfaces. So rather than use @@ -90,11 +70,58 @@ pgx support such as github.com/shopspring/decimal. These types can be registered logic. See https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid for a example integrations. -Entirely New Type Support +New PostgreSQL Type Support -If the PostgreSQL type is not already supported then an OID / Codec mapping can be registered with Map.RegisterType. -There is no difference between a Codec defined and registered by the application and a Codec built in to pgtype. See any -of the Codecs in pgtype for Codec examples and for examples of type registration. +pgtype uses the PostgreSQL OID to determine how to encode or decode a value. pgtype supports array, composite, domain, +and enum types. However, any type created in PostgreSQL with CREATE TYPE will receive a new OID. This means that the OID +of each new PostgreSQL type must be registered for pgtype to handle values of that type with the correct Codec. + +The pgx.Conn LoadType method can return a *Type for array, composite, domain, and enum types by inspecting the database +metadata. This *Type can then be registered with Map.RegisterType. + +For example, the following function could be called after a connection is established: + + func RegisterDataTypes(ctx context.Context, conn *pgx.Conn) error { + dataTypeNames := []string{ + "foo", + "_foo", + "bar", + "_bar", + } + + for _, typeName := range dataTypeNames { + dataType, err := conn.LoadType(ctx, typeName) + if err != nil { + return err + } + conn.TypeMap().RegisterType(dataType) + } + + return nil + } + +A type cannot be registered unless all types it depends on are already registered. e.g. An array type cannot be +registered until its element type is registered. + +ArrayCodec implements support for arrays. If pgtype supports type T then it can easily support []T by registering an +ArrayCodec for the appropriate PostgreSQL OID. In addition, Array[T] type can support multi-dimensional arrays. + +CompositeCodec implements support for PostgreSQL composite types. Go structs can be scanned into if the public fields of +the struct are in the exact order and type of the PostgreSQL type or by implementing CompositeIndexScanner and +CompositeIndexGetter. + +Domain types are treated as their underlying type if the underlying type and the domain type are registered. + +PostgreSQL enums can usually be treated as text. However, EnumCodec implements support for interning strings which can +reduce memory usage. + +While pgtype will often still work with unregistered types it is highly recommended that all types be registered due to +an improvement in performance and the elimination of certain edge cases. + +If an entirely new PostgreSQL type (e.g. PostGIS types) is used then the application or a library can create a new +Codec. Then the OID / Codec mapping can be registered with Map.RegisterType. There is no difference between a Codec +defined and registered by the application and a Codec built in to pgtype. See any of the Codecs in pgtype for Codec +examples and for examples of type registration. Encoding Unknown Types diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go b/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go index 4743643..2f34f4c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go @@ -1,14 +1,11 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "errors" "fmt" "strings" - "unicode" - "unicode/utf8" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -43,7 +40,7 @@ func (h *Hstore) Scan(src any) error { switch src := src.(type) { case string: - return scanPlanTextAnyToHstoreScanner{}.Scan([]byte(src), h) + return scanPlanTextAnyToHstoreScanner{}.scanString(src, h) } return fmt.Errorf("cannot scan %T", src) @@ -124,8 +121,15 @@ func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, e return nil, err } - if hstore == nil { - return nil, nil + if len(hstore) == 0 { + // distinguish between empty and nil: Not strictly required by Postgres, since its protocol + // explicitly marks NULL column values separately. However, the Binary codec does this, and + // this means we can "round trip" Encode and Scan without data loss. + // nil: []byte(nil); empty: []byte{} + if hstore == nil { + return nil, nil + } + return []byte{}, nil } firstPair := true @@ -134,16 +138,23 @@ func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, e if firstPair { firstPair = false } else { - buf = append(buf, ',') + buf = append(buf, ',', ' ') } - buf = append(buf, quoteHstoreElementIfNeeded(k)...) + // unconditionally quote hstore keys/values like Postgres does + // this avoids a Mac OS X Postgres hstore parsing bug: + // https://www.postgresql.org/message-id/CA%2BHWA9awUW0%2BRV_gO9r1ABZwGoZxPztcJxPy8vMFSTbTfi4jig%40mail.gmail.com + buf = append(buf, '"') + buf = append(buf, quoteArrayReplacer.Replace(k)...) + buf = append(buf, '"') buf = append(buf, "=>"...) if v == nil { buf = append(buf, "NULL"...) } else { - buf = append(buf, quoteHstoreElementIfNeeded(*v)...) + buf = append(buf, '"') + buf = append(buf, quoteArrayReplacer.Replace(*v)...) + buf = append(buf, '"') } } @@ -174,25 +185,28 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error { scanner := (dst).(HstoreScanner) if src == nil { - return scanner.ScanHstore(Hstore{}) + return scanner.ScanHstore(Hstore(nil)) } rp := 0 - if len(src[rp:]) < 4 { + const uint32Len = 4 + if len(src[rp:]) < uint32Len { return fmt.Errorf("hstore incomplete %v", src) } pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 + rp += uint32Len hstore := make(Hstore, pairCount) + // one allocation for all *string, rather than one per string, just like text parsing + valueStrings := make([]string, pairCount) for i := 0; i < pairCount; i++ { - if len(src[rp:]) < 4 { + if len(src[rp:]) < uint32Len { return fmt.Errorf("hstore incomplete %v", src) } keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 + rp += uint32Len if len(src[rp:]) < keyLen { return fmt.Errorf("hstore incomplete %v", src) @@ -200,26 +214,17 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error { key := string(src[rp : rp+keyLen]) rp += keyLen - if len(src[rp:]) < 4 { + if len(src[rp:]) < uint32Len { return fmt.Errorf("hstore incomplete %v", src) } valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 - var valueBuf []byte if valueLen >= 0 { - valueBuf = src[rp : rp+valueLen] + valueStrings[i] = string(src[rp : rp+valueLen]) rp += valueLen - } - var value Text - err := scanPlanTextAnyToTextScanner{}.Scan(valueBuf, &value) - if err != nil { - return err - } - - if value.Valid { - hstore[key] = &value.String + hstore[key] = &valueStrings[i] } else { hstore[key] = nil } @@ -230,28 +235,22 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error { type scanPlanTextAnyToHstoreScanner struct{} -func (scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error { +func (s scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error { scanner := (dst).(HstoreScanner) if src == nil { - return scanner.ScanHstore(Hstore{}) + return scanner.ScanHstore(Hstore(nil)) } + return s.scanString(string(src), scanner) +} - keys, values, err := parseHstore(string(src)) +// scanString does not return nil hstore values because string cannot be nil. +func (scanPlanTextAnyToHstoreScanner) scanString(src string, scanner HstoreScanner) error { + hstore, err := parseHstore(src) if err != nil { return err } - - m := make(Hstore, len(keys)) - for i := range keys { - if values[i].Valid { - m[keys[i]] = &values[i].String - } else { - m[keys[i]] = nil - } - } - - return scanner.ScanHstore(m) + return scanner.ScanHstore(hstore) } func (c HstoreCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { @@ -271,191 +270,217 @@ func (c HstoreCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) ( return hstore, nil } -var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) - -func quoteHstoreElement(src string) string { - return `"` + quoteArrayReplacer.Replace(src) + `"` -} - -func quoteHstoreElementIfNeeded(src string) string { - if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || strings.ContainsAny(src, ` {},"\=>`) { - return quoteArrayElement(src) - } - return src -} - -const ( - hsPre = iota - hsKey - hsSep - hsVal - hsNul - hsNext -) - type hstoreParser struct { - str string - pos int + str string + pos int + nextBackslash int } func newHSP(in string) *hstoreParser { return &hstoreParser{ - pos: 0, - str: in, + pos: 0, + str: in, + nextBackslash: strings.IndexByte(in, '\\'), } } -func (p *hstoreParser) Consume() (r rune, end bool) { +func (p *hstoreParser) atEnd() bool { + return p.pos >= len(p.str) +} + +// consume returns the next byte of the string, or end if the string is done. +func (p *hstoreParser) consume() (b byte, end bool) { if p.pos >= len(p.str) { - end = true - return + return 0, true } - r, w := utf8.DecodeRuneInString(p.str[p.pos:]) - p.pos += w - return + b = p.str[p.pos] + p.pos++ + return b, false } -func (p *hstoreParser) Peek() (r rune, end bool) { - if p.pos >= len(p.str) { - end = true - return - } - r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) - return +func unexpectedByteErr(actualB byte, expectedB byte) error { + return fmt.Errorf("expected '%c' ('%#v'); found '%c' ('%#v')", expectedB, expectedB, actualB, actualB) } -// parseHstore parses the string representation of an hstore column (the same -// you would get from an ordinary SELECT) into two slices of keys and values. it -// is used internally in the default parsing of hstores. -func parseHstore(s string) (k []string, v []Text, err error) { - if s == "" { - return +// consumeExpectedByte consumes expectedB from the string, or returns an error. +func (p *hstoreParser) consumeExpectedByte(expectedB byte) error { + nextB, end := p.consume() + if end { + return fmt.Errorf("expected '%c' ('%#v'); found end", expectedB, expectedB) + } + if nextB != expectedB { + return unexpectedByteErr(nextB, expectedB) + } + return nil +} + +// consumeExpected2 consumes two expected bytes or returns an error. +// This was a bit faster than using a string argument (better inlining? Not sure). +func (p *hstoreParser) consumeExpected2(one byte, two byte) error { + if p.pos+2 > len(p.str) { + return errors.New("unexpected end of string") + } + if p.str[p.pos] != one { + return unexpectedByteErr(p.str[p.pos], one) + } + if p.str[p.pos+1] != two { + return unexpectedByteErr(p.str[p.pos+1], two) + } + p.pos += 2 + return nil +} + +var errEOSInQuoted = errors.New(`found end before closing double-quote ('"')`) + +// consumeDoubleQuoted consumes a double-quoted string from p. The double quote must have been +// parsed already. This copies the string from the backing string so it can be garbage collected. +func (p *hstoreParser) consumeDoubleQuoted() (string, error) { + // fast path: assume most keys/values do not contain escapes + nextDoubleQuote := strings.IndexByte(p.str[p.pos:], '"') + if nextDoubleQuote == -1 { + return "", errEOSInQuoted + } + nextDoubleQuote += p.pos + if p.nextBackslash == -1 || p.nextBackslash > nextDoubleQuote { + // clone the string from the source string to ensure it can be garbage collected separately + // TODO: use strings.Clone on Go 1.20; this could get optimized away + s := strings.Clone(p.str[p.pos:nextDoubleQuote]) + p.pos = nextDoubleQuote + 1 + return s, nil } - buf := bytes.Buffer{} - keys := []string{} - values := []Text{} + // slow path: string contains escapes + s, err := p.consumeDoubleQuotedWithEscapes(p.nextBackslash) + p.nextBackslash = strings.IndexByte(p.str[p.pos:], '\\') + if p.nextBackslash != -1 { + p.nextBackslash += p.pos + } + return s, err +} + +// consumeDoubleQuotedWithEscapes consumes a double-quoted string containing escapes, starting +// at p.pos, and with the first backslash at firstBackslash. This copies the string so it can be +// garbage collected separately. +func (p *hstoreParser) consumeDoubleQuotedWithEscapes(firstBackslash int) (string, error) { + // copy the prefix that does not contain backslashes + var builder strings.Builder + builder.WriteString(p.str[p.pos:firstBackslash]) + + // skip to the backslash + p.pos = firstBackslash + + // copy bytes until the end, unescaping backslashes + for { + nextB, end := p.consume() + if end { + return "", errEOSInQuoted + } else if nextB == '"' { + break + } else if nextB == '\\' { + // escape: skip the backslash and copy the char + nextB, end = p.consume() + if end { + return "", errEOSInQuoted + } + if !(nextB == '\\' || nextB == '"') { + return "", fmt.Errorf("unexpected escape in quoted string: found '%#v'", nextB) + } + builder.WriteByte(nextB) + } else { + // normal byte: copy it + builder.WriteByte(nextB) + } + } + return builder.String(), nil +} + +// consumePairSeparator consumes the Hstore pair separator ", " or returns an error. +func (p *hstoreParser) consumePairSeparator() error { + return p.consumeExpected2(',', ' ') +} + +// consumeKVSeparator consumes the Hstore key/value separator "=>" or returns an error. +func (p *hstoreParser) consumeKVSeparator() error { + return p.consumeExpected2('=', '>') +} + +// consumeDoubleQuotedOrNull consumes the Hstore key/value separator "=>" or returns an error. +func (p *hstoreParser) consumeDoubleQuotedOrNull() (Text, error) { + // peek at the next byte + if p.atEnd() { + return Text{}, errors.New("found end instead of value") + } + next := p.str[p.pos] + if next == 'N' { + // must be the exact string NULL: use consumeExpected2 twice + err := p.consumeExpected2('N', 'U') + if err != nil { + return Text{}, err + } + err = p.consumeExpected2('L', 'L') + if err != nil { + return Text{}, err + } + return Text{String: "", Valid: false}, nil + } else if next != '"' { + return Text{}, unexpectedByteErr(next, '"') + } + + // skip the double quote + p.pos += 1 + s, err := p.consumeDoubleQuoted() + if err != nil { + return Text{}, err + } + return Text{String: s, Valid: true}, nil +} + +func parseHstore(s string) (Hstore, error) { p := newHSP(s) - r, end := p.Consume() - state := hsPre - - for !end { - switch state { - case hsPre: - if r == '"' { - state = hsKey - } else { - err = errors.New("String does not begin with \"") - } - case hsKey: - switch r { - case '"': //End of the key - keys = append(keys, buf.String()) - buf = bytes.Buffer{} - state = hsSep - case '\\': //Potential escaped character - n, end := p.Consume() - switch { - case end: - err = errors.New("Found EOS in key, expecting character or \"") - case n == '"', n == '\\': - buf.WriteRune(n) - default: - buf.WriteRune(r) - buf.WriteRune(n) - } - default: //Any other character - buf.WriteRune(r) - } - case hsSep: - if r == '=' { - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after '=', expecting '>'") - case r == '>': - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") - case r == '"': - state = hsVal - case r == 'N': - state = hsNul - default: - err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) - } - default: - err = fmt.Errorf("Invalid character after '=', expecting '>'") - } - } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) - } - case hsVal: - switch r { - case '"': //End of the value - values = append(values, Text{String: buf.String(), Valid: true}) - buf = bytes.Buffer{} - state = hsNext - case '\\': //Potential escaped character - n, end := p.Consume() - switch { - case end: - err = errors.New("Found EOS in key, expecting character or \"") - case n == '"', n == '\\': - buf.WriteRune(n) - default: - buf.WriteRune(r) - buf.WriteRune(n) - } - default: //Any other character - buf.WriteRune(r) - } - case hsNul: - nulBuf := make([]rune, 3) - nulBuf[0] = r - for i := 1; i < 3; i++ { - r, end = p.Consume() - if end { - err = errors.New("Found EOS in NULL value") - return - } - nulBuf[i] = r - } - if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { - values = append(values, Text{}) - state = hsNext - } else { - err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) - } - case hsNext: - if r == ',' { - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after ',', expcting space") - case (unicode.IsSpace(r)): - r, end = p.Consume() - state = hsKey - default: - err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r) - } - } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) + // This is an over-estimate of the number of key/value pairs. Use '>' because I am guessing it + // is less likely to occur in keys/values than '=' or ','. + numPairsEstimate := strings.Count(s, ">") + // makes one allocation of strings for the entire Hstore, rather than one allocation per value. + valueStrings := make([]string, 0, numPairsEstimate) + result := make(Hstore, numPairsEstimate) + first := true + for !p.atEnd() { + if !first { + err := p.consumePairSeparator() + if err != nil { + return nil, err } + } else { + first = false } + err := p.consumeExpectedByte('"') if err != nil { - return + return nil, err + } + + key, err := p.consumeDoubleQuoted() + if err != nil { + return nil, err + } + + err = p.consumeKVSeparator() + if err != nil { + return nil, err + } + + value, err := p.consumeDoubleQuotedOrNull() + if err != nil { + return nil, err + } + if value.Valid { + valueStrings = append(valueStrings, value.String) + result[key] = &valueStrings[len(valueStrings)-1] + } else { + result[key] = nil } - r, end = p.Consume() } - if state != hsNext { - err = errors.New("Improperly formatted hstore") - return - } - k = keys - v = values - return + + return result, nil } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/int.go b/vendor/github.com/jackc/pgx/v5/pgtype/int.go index 1cda0ba..90a20a2 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/int.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/int.go @@ -33,7 +33,7 @@ func (dst *Int2) ScanInt64(n Int8) error { } if n.Int64 < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", n.Int64) + return fmt.Errorf("%d is less than minimum value for Int2", n.Int64) } if n.Int64 > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", n.Int64) @@ -593,7 +593,7 @@ func (dst *Int4) ScanInt64(n Int8) error { } if n.Int64 < math.MinInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", n.Int64) + return fmt.Errorf("%d is less than minimum value for Int4", n.Int64) } if n.Int64 > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", n.Int64) @@ -1164,7 +1164,7 @@ func (dst *Int8) ScanInt64(n Int8) error { } if n.Int64 < math.MinInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", n.Int64) + return fmt.Errorf("%d is less than minimum value for Int8", n.Int64) } if n.Int64 > math.MaxInt64 { return fmt.Errorf("%d is greater than maximum value for Int8", n.Int64) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/int.go.erb b/vendor/github.com/jackc/pgx/v5/pgtype/int.go.erb index 572408e..e0c8b7a 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/int.go.erb +++ b/vendor/github.com/jackc/pgx/v5/pgtype/int.go.erb @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "fmt" "math" "strconv" @@ -34,7 +35,7 @@ func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error { } if n.Int64 < math.MinInt<%= pg_bit_size %> { - return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int64) + return fmt.Errorf("%d is less than minimum value for Int<%= pg_byte_size %>", n.Int64) } if n.Int64 > math.MaxInt<%= pg_bit_size %> { return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int64) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/json.go b/vendor/github.com/jackc/pgx/v5/pgtype/json.go index d0d98fc..d332dd0 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/json.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/json.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql" "database/sql/driver" "encoding/json" "fmt" @@ -23,6 +24,19 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod return encodePlanJSONCodecEitherFormatString{} case []byte: return encodePlanJSONCodecEitherFormatByteSlice{} + + // Must come before trying wrap encode plans because a pointer to a struct may be unwrapped to a struct that can be + // marshalled. + // + // https://github.com/jackc/pgx/issues/1681 + case json.Marshaler: + return encodePlanJSONCodecEitherFormatMarshal{} + + // Cannot rely on driver.Valuer being handled later because anything can be marshalled. + // + // https://github.com/jackc/pgx/issues/1430 + case driver.Valuer: + return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format} } // Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the @@ -78,14 +92,36 @@ func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan switch target.(type) { case *string: return scanPlanAnyToString{} + + case **string: + // This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better + // solution would be. + // + // https://github.com/jackc/pgx/issues/1470 -- **string + // https://github.com/jackc/pgx/issues/1691 -- ** anything else + + if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok { + if nextPlan := m.planScan(oid, format, nextDst); nextPlan != nil { + if _, failed := nextPlan.(*scanPlanFail); !failed { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + } + case *[]byte: return scanPlanJSONToByteSlice{} case BytesScanner: return scanPlanBinaryBytesToBytesScanner{} - default: - return scanPlanJSONToJSONUnmarshal{} + + // Cannot rely on sql.Scanner being handled later because scanPlanJSONToJSONUnmarshal will take precedence. + // + // https://github.com/jackc/pgx/issues/1418 + case sql.Scanner: + return &scanPlanSQLScanner{formatCode: format} } + return scanPlanJSONToJSONUnmarshal{} } type scanPlanAnyToString struct{} @@ -125,7 +161,7 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { if dstValue.Kind() == reflect.Ptr { el := dstValue.Elem() switch el.Kind() { - case reflect.Ptr, reflect.Slice, reflect.Map: + case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Interface: el.Set(reflect.Zero(el.Type())) return nil } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go b/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go index a5f4ed3..0e58fd0 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go @@ -33,23 +33,6 @@ var big10 *big.Int = big.NewInt(10) var big100 *big.Int = big.NewInt(100) var big1000 *big.Int = big.NewInt(1000) -var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8) -var bigMinInt8 *big.Int = big.NewInt(math.MinInt8) -var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16) -var bigMinInt16 *big.Int = big.NewInt(math.MinInt16) -var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32) -var bigMinInt32 *big.Int = big.NewInt(math.MinInt32) -var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64) -var bigMinInt64 *big.Int = big.NewInt(math.MinInt64) -var bigMaxInt *big.Int = big.NewInt(int64(maxInt)) -var bigMinInt *big.Int = big.NewInt(int64(minInt)) - -var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8) -var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16) -var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32) -var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64)) -var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint)) - var bigNBase *big.Int = big.NewInt(nbase) var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) @@ -161,20 +144,20 @@ func (n *Numeric) toBigInt() (*big.Int, error) { } func parseNumericString(str string) (n *big.Int, exp int32, err error) { - parts := strings.SplitN(str, ".", 2) - digits := strings.Join(parts, "") + idx := strings.IndexByte(str, '.') - if len(parts) > 1 { - exp = int32(-len(parts[1])) - } else { - for len(digits) > 1 && digits[len(digits)-1] == '0' && digits[len(digits)-2] != '-' { - digits = digits[:len(digits)-1] + if idx == -1 { + for len(str) > 1 && str[len(str)-1] == '0' && str[len(str)-2] != '-' { + str = str[:len(str)-1] exp++ } + } else { + exp = int32(-(len(str) - idx - 1)) + str = str[:idx] + str[idx+1:] } accum := &big.Int{} - if _, ok := accum.SetString(digits, 10); !ok { + if _, ok := accum.SetString(str, 10); !ok { return nil, 0, fmt.Errorf("%s is not a number", str) } @@ -240,10 +223,29 @@ func (n Numeric) MarshalJSON() ([]byte, error) { return n.numberTextBytes(), nil } +func (n *Numeric) UnmarshalJSON(src []byte) error { + if bytes.Equal(src, []byte(`null`)) { + *n = Numeric{} + return nil + } + if bytes.Equal(src, []byte(`"NaN"`)) { + *n = Numeric{NaN: true, Valid: true} + return nil + } + return scanPlanTextAnyToNumericScanner{}.Scan(src, n) +} + // numberString returns a string of the number. undefined if NaN, infinite, or NULL func (n Numeric) numberTextBytes() []byte { intStr := n.Int.String() + buf := &bytes.Buffer{} + + if len(intStr) > 0 && intStr[:1] == "-" { + intStr = intStr[1:] + buf.WriteByte('-') + } + exp := int(n.Exp) if exp > 0 { buf.WriteString(intStr) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go index 2aa96e6..59d833a 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go @@ -44,7 +44,7 @@ const ( MacaddrOID = 829 InetOID = 869 BoolArrayOID = 1000 - QCharArrayOID = 1003 + QCharArrayOID = 1002 NameArrayOID = 1003 Int2ArrayOID = 1005 Int4ArrayOID = 1007 @@ -104,6 +104,8 @@ const ( TstzrangeArrayOID = 3911 Int8rangeOID = 3926 Int8rangeArrayOID = 3927 + JSONPathOID = 4072 + JSONPathArrayOID = 4073 Int4multirangeOID = 4451 NummultirangeOID = 4532 TsmultirangeOID = 4533 @@ -145,7 +147,7 @@ const ( BinaryFormatCode = 1 ) -// A Codec converts between Go and PostgreSQL values. +// A Codec converts between Go and PostgreSQL values. A Codec must not be mutated after it is registered with a Map. type Codec interface { // FormatSupported returns true if the format is supported. FormatSupported(int16) bool @@ -176,6 +178,7 @@ func (e *nullAssignmentError) Error() string { return fmt.Sprintf("cannot assign NULL to %T", e.dst) } +// Type represents a PostgreSQL data type. It must not be mutated after it is registered with a Map. type Type struct { Codec Codec Name string @@ -192,7 +195,8 @@ type Map struct { reflectTypeToType map[reflect.Type]*Type - memoizedScanPlans map[uint32]map[reflect.Type][2]ScanPlan + memoizedScanPlans map[uint32]map[reflect.Type][2]ScanPlan + memoizedEncodePlans map[uint32]map[reflect.Type][2]EncodePlan // TryWrapEncodePlanFuncs is a slice of functions that will wrap a value that cannot be encoded by the Codec. Every // time a wrapper is found the PlanEncode method will be recursively called with the new value. This allows several layers of wrappers @@ -208,13 +212,16 @@ type Map struct { } func NewMap() *Map { - m := &Map{ + defaultMapInitOnce.Do(initDefaultMap) + + return &Map{ oidToType: make(map[uint32]*Type), nameToType: make(map[string]*Type), reflectTypeToName: make(map[reflect.Type]string), oidToFormatCode: make(map[uint32]int16), - memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan), + memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan), + memoizedEncodePlans: make(map[uint32]map[reflect.Type][2]EncodePlan), TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ TryWrapDerefPointerEncodePlan, @@ -223,6 +230,7 @@ func NewMap() *Map { TryWrapStructEncodePlan, TryWrapSliceEncodePlan, TryWrapMultiDimSliceEncodePlan, + TryWrapArrayEncodePlan, }, TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{ @@ -232,184 +240,12 @@ func NewMap() *Map { TryWrapStructScanPlan, TryWrapPtrSliceScanPlan, TryWrapPtrMultiDimSliceScanPlan, + TryWrapPtrArrayScanPlan, }, } - - // Base types - m.RegisterType(&Type{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) - m.RegisterType(&Type{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) - m.RegisterType(&Type{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) - m.RegisterType(&Type{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) - m.RegisterType(&Type{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) - m.RegisterType(&Type{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) - m.RegisterType(&Type{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) - m.RegisterType(&Type{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) - m.RegisterType(&Type{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) - m.RegisterType(&Type{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) - m.RegisterType(&Type{Name: "date", OID: DateOID, Codec: DateCodec{}}) - m.RegisterType(&Type{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) - m.RegisterType(&Type{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) - m.RegisterType(&Type{Name: "inet", OID: InetOID, Codec: InetCodec{}}) - m.RegisterType(&Type{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) - m.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) - m.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) - m.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) - m.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) - m.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) - m.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) - m.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) - m.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) - m.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}}) - m.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) - m.RegisterType(&Type{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) - m.RegisterType(&Type{Name: "path", OID: PathOID, Codec: PathCodec{}}) - m.RegisterType(&Type{Name: "point", OID: PointOID, Codec: PointCodec{}}) - m.RegisterType(&Type{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) - m.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) - m.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) - m.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) - m.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) - m.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) - m.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) - m.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) - m.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) - m.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) - m.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) - m.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) - - // Range types - m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: m.oidToType[DateOID]}}) - m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int4OID]}}) - m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int8OID]}}) - m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[NumericOID]}}) - m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestampOID]}}) - m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestamptzOID]}}) - - // Multirange types - m.RegisterType(&Type{Name: "datemultirange", OID: DatemultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[DaterangeOID]}}) - m.RegisterType(&Type{Name: "int4multirange", OID: Int4multirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[Int4rangeOID]}}) - m.RegisterType(&Type{Name: "int8multirange", OID: Int8multirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[Int8rangeOID]}}) - m.RegisterType(&Type{Name: "nummultirange", OID: NummultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[NumrangeOID]}}) - m.RegisterType(&Type{Name: "tsmultirange", OID: TsmultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[TsrangeOID]}}) - m.RegisterType(&Type{Name: "tstzmultirange", OID: TstzmultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[TstzrangeOID]}}) - - // Array types - m.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ACLItemOID]}}) - m.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BitOID]}}) - m.RegisterType(&Type{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BoolOID]}}) - m.RegisterType(&Type{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BoxOID]}}) - m.RegisterType(&Type{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BPCharOID]}}) - m.RegisterType(&Type{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ByteaOID]}}) - m.RegisterType(&Type{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[QCharOID]}}) - m.RegisterType(&Type{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CIDOID]}}) - m.RegisterType(&Type{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CIDROID]}}) - m.RegisterType(&Type{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CircleOID]}}) - m.RegisterType(&Type{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[DateOID]}}) - m.RegisterType(&Type{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[DaterangeOID]}}) - m.RegisterType(&Type{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Float4OID]}}) - m.RegisterType(&Type{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Float8OID]}}) - m.RegisterType(&Type{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[InetOID]}}) - m.RegisterType(&Type{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int2OID]}}) - m.RegisterType(&Type{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int4OID]}}) - m.RegisterType(&Type{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int4rangeOID]}}) - m.RegisterType(&Type{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int8OID]}}) - m.RegisterType(&Type{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int8rangeOID]}}) - m.RegisterType(&Type{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[IntervalOID]}}) - m.RegisterType(&Type{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[JSONOID]}}) - m.RegisterType(&Type{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[JSONBOID]}}) - m.RegisterType(&Type{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[LineOID]}}) - m.RegisterType(&Type{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[LsegOID]}}) - m.RegisterType(&Type{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[MacaddrOID]}}) - m.RegisterType(&Type{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NameOID]}}) - m.RegisterType(&Type{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NumericOID]}}) - m.RegisterType(&Type{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NumrangeOID]}}) - m.RegisterType(&Type{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[OIDOID]}}) - m.RegisterType(&Type{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PathOID]}}) - m.RegisterType(&Type{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PointOID]}}) - m.RegisterType(&Type{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PolygonOID]}}) - m.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[RecordOID]}}) - m.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TextOID]}}) - m.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TIDOID]}}) - m.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimeOID]}}) - m.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimestampOID]}}) - m.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimestamptzOID]}}) - m.RegisterType(&Type{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TsrangeOID]}}) - m.RegisterType(&Type{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TstzrangeOID]}}) - m.RegisterType(&Type{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[UUIDOID]}}) - m.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarbitOID]}}) - m.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarcharOID]}}) - m.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[XIDOID]}}) - - // Integer types that directly map to a PostgreSQL type - registerDefaultPgTypeVariants[int16](m, "int2") - registerDefaultPgTypeVariants[int32](m, "int4") - registerDefaultPgTypeVariants[int64](m, "int8") - - // Integer types that do not have a direct match to a PostgreSQL type - registerDefaultPgTypeVariants[int8](m, "int8") - registerDefaultPgTypeVariants[int](m, "int8") - registerDefaultPgTypeVariants[uint8](m, "int8") - registerDefaultPgTypeVariants[uint16](m, "int8") - registerDefaultPgTypeVariants[uint32](m, "int8") - registerDefaultPgTypeVariants[uint64](m, "numeric") - registerDefaultPgTypeVariants[uint](m, "numeric") - - registerDefaultPgTypeVariants[float32](m, "float4") - registerDefaultPgTypeVariants[float64](m, "float8") - - registerDefaultPgTypeVariants[bool](m, "bool") - registerDefaultPgTypeVariants[time.Time](m, "timestamptz") - registerDefaultPgTypeVariants[time.Duration](m, "interval") - registerDefaultPgTypeVariants[string](m, "text") - registerDefaultPgTypeVariants[[]byte](m, "bytea") - - registerDefaultPgTypeVariants[net.IP](m, "inet") - registerDefaultPgTypeVariants[net.IPNet](m, "cidr") - registerDefaultPgTypeVariants[netip.Addr](m, "inet") - registerDefaultPgTypeVariants[netip.Prefix](m, "cidr") - - // pgtype provided structs - registerDefaultPgTypeVariants[Bits](m, "varbit") - registerDefaultPgTypeVariants[Bool](m, "bool") - registerDefaultPgTypeVariants[Box](m, "box") - registerDefaultPgTypeVariants[Circle](m, "circle") - registerDefaultPgTypeVariants[Date](m, "date") - registerDefaultPgTypeVariants[Range[Date]](m, "daterange") - registerDefaultPgTypeVariants[Multirange[Range[Date]]](m, "datemultirange") - registerDefaultPgTypeVariants[Float4](m, "float4") - registerDefaultPgTypeVariants[Float8](m, "float8") - registerDefaultPgTypeVariants[Range[Float8]](m, "numrange") // There is no PostgreSQL builtin float8range so map it to numrange. - registerDefaultPgTypeVariants[Multirange[Range[Float8]]](m, "nummultirange") // There is no PostgreSQL builtin float8multirange so map it to nummultirange. - registerDefaultPgTypeVariants[Int2](m, "int2") - registerDefaultPgTypeVariants[Int4](m, "int4") - registerDefaultPgTypeVariants[Range[Int4]](m, "int4range") - registerDefaultPgTypeVariants[Multirange[Range[Int4]]](m, "int4multirange") - registerDefaultPgTypeVariants[Int8](m, "int8") - registerDefaultPgTypeVariants[Range[Int8]](m, "int8range") - registerDefaultPgTypeVariants[Multirange[Range[Int8]]](m, "int8multirange") - registerDefaultPgTypeVariants[Interval](m, "interval") - registerDefaultPgTypeVariants[Line](m, "line") - registerDefaultPgTypeVariants[Lseg](m, "lseg") - registerDefaultPgTypeVariants[Numeric](m, "numeric") - registerDefaultPgTypeVariants[Range[Numeric]](m, "numrange") - registerDefaultPgTypeVariants[Multirange[Range[Numeric]]](m, "nummultirange") - registerDefaultPgTypeVariants[Path](m, "path") - registerDefaultPgTypeVariants[Point](m, "point") - registerDefaultPgTypeVariants[Polygon](m, "polygon") - registerDefaultPgTypeVariants[TID](m, "tid") - registerDefaultPgTypeVariants[Text](m, "text") - registerDefaultPgTypeVariants[Time](m, "time") - registerDefaultPgTypeVariants[Timestamp](m, "timestamp") - registerDefaultPgTypeVariants[Timestamptz](m, "timestamptz") - registerDefaultPgTypeVariants[Range[Timestamp]](m, "tsrange") - registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](m, "tsmultirange") - registerDefaultPgTypeVariants[Range[Timestamptz]](m, "tstzrange") - registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](m, "tstzmultirange") - registerDefaultPgTypeVariants[UUID](m, "uuid") - - return m } +// RegisterType registers a data type with the Map. t must not be mutated after it is registered. func (m *Map) RegisterType(t *Type) { m.oidToType[t.OID] = t m.nameToType[t.Name] = t @@ -420,6 +256,9 @@ func (m *Map) RegisterType(t *Type) { for k := range m.memoizedScanPlans { delete(m.memoizedScanPlans, k) } + for k := range m.memoizedEncodePlans { + delete(m.memoizedEncodePlans, k) + } } // RegisterDefaultPgType registers a mapping of a Go type to a PostgreSQL type name. Typically the data type to be @@ -433,15 +272,27 @@ func (m *Map) RegisterDefaultPgType(value any, name string) { for k := range m.memoizedScanPlans { delete(m.memoizedScanPlans, k) } + for k := range m.memoizedEncodePlans { + delete(m.memoizedEncodePlans, k) + } } +// TypeForOID returns the Type registered for the given OID. The returned Type must not be mutated. func (m *Map) TypeForOID(oid uint32) (*Type, bool) { - dt, ok := m.oidToType[oid] + if dt, ok := m.oidToType[oid]; ok { + return dt, true + } + + dt, ok := defaultMap.oidToType[oid] return dt, ok } +// TypeForName returns the Type registered for the given name. The returned Type must not be mutated. func (m *Map) TypeForName(name string) (*Type, bool) { - dt, ok := m.nameToType[name] + if dt, ok := m.nameToType[name]; ok { + return dt, true + } + dt, ok := defaultMap.nameToType[name] return dt, ok } @@ -449,30 +300,39 @@ func (m *Map) buildReflectTypeToType() { m.reflectTypeToType = make(map[reflect.Type]*Type) for reflectType, name := range m.reflectTypeToName { - if dt, ok := m.nameToType[name]; ok { + if dt, ok := m.TypeForName(name); ok { m.reflectTypeToType[reflectType] = dt } } } // TypeForValue finds a data type suitable for v. Use RegisterType to register types that can encode and decode -// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. +// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. The returned Type +// must not be mutated. func (m *Map) TypeForValue(v any) (*Type, bool) { if m.reflectTypeToType == nil { m.buildReflectTypeToType() } - dt, ok := m.reflectTypeToType[reflect.TypeOf(v)] + if dt, ok := m.reflectTypeToType[reflect.TypeOf(v)]; ok { + return dt, true + } + + dt, ok := defaultMap.reflectTypeToType[reflect.TypeOf(v)] return dt, ok } // FormatCodeForOID returns the preferred format code for type oid. If the type is not registered it returns the text // format code. func (m *Map) FormatCodeForOID(oid uint32) int16 { - fc, ok := m.oidToFormatCode[oid] - if ok { + if fc, ok := m.oidToFormatCode[oid]; ok { return fc } + + if fc, ok := defaultMap.oidToFormatCode[oid]; ok { + return fc + } + return TextFormatCode } @@ -573,6 +433,14 @@ func (plan *scanPlanFail) Scan(src []byte, dst any) error { return plan.Scan(src, dst) } } + for oid := range defaultMap.oidToType { + if _, ok := plan.m.oidToType[oid]; !ok { + plan := plan.m.planScan(oid, plan.formatCode, dst) + if _, ok := plan.(*scanPlanFail); !ok { + return plan.Scan(src, dst) + } + } + } } var format string @@ -586,7 +454,7 @@ func (plan *scanPlanFail) Scan(src []byte, dst any) error { } var dataTypeName string - if t, ok := plan.m.oidToType[plan.oid]; ok { + if t, ok := plan.m.TypeForOID(plan.oid); ok { dataTypeName = t.Name } else { dataTypeName = "unknown type" @@ -652,6 +520,7 @@ var elemKindToPointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]refl reflect.Float32: reflect.TypeOf(new(float32)), reflect.Float64: reflect.TypeOf(new(float64)), reflect.String: reflect.TypeOf(new(string)), + reflect.Bool: reflect.TypeOf(new(bool)), } type underlyingTypeScanPlan struct { @@ -1018,7 +887,7 @@ func TryWrapStructScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValu var targetElemValue reflect.Value if targetValue.IsNil() { - targetElemValue = reflect.New(targetValue.Type().Elem()) + targetElemValue = reflect.Zero(targetValue.Type().Elem()) } else { targetElemValue = targetValue.Elem() } @@ -1075,15 +944,16 @@ func TryWrapPtrSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextVa return &wrapPtrSliceScanPlan[time.Time]{}, (*FlatArray[time.Time])(target), true } - targetValue := reflect.ValueOf(target) - if targetValue.Kind() != reflect.Ptr { + targetType := reflect.TypeOf(target) + if targetType.Kind() != reflect.Ptr { return nil, nil, false } - targetElemValue := targetValue.Elem() + targetElemType := targetType.Elem() - if targetElemValue.Kind() == reflect.Slice { - return &wrapPtrSliceReflectScanPlan{}, &anySliceArrayReflect{slice: targetElemValue}, true + if targetElemType.Kind() == reflect.Slice { + slice := reflect.New(targetElemType).Elem() + return &wrapPtrSliceReflectScanPlan{}, &anySliceArrayReflect{slice: slice}, true } return nil, nil, false } @@ -1139,6 +1009,31 @@ func (plan *wrapPtrMultiDimSliceScanPlan) Scan(src []byte, target any) error { return plan.next.Scan(src, &anyMultiDimSliceArray{slice: reflect.ValueOf(target).Elem()}) } +// TryWrapPtrArrayScanPlan tries to wrap a pointer to a single dimension array. +func TryWrapPtrArrayScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) { + targetValue := reflect.ValueOf(target) + if targetValue.Kind() != reflect.Ptr { + return nil, nil, false + } + + targetElemValue := targetValue.Elem() + + if targetElemValue.Kind() == reflect.Array { + return &wrapPtrArrayReflectScanPlan{}, &anyArrayArrayReflect{array: targetElemValue}, true + } + return nil, nil, false +} + +type wrapPtrArrayReflectScanPlan struct { + next ScanPlan +} + +func (plan *wrapPtrArrayReflectScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapPtrArrayReflectScanPlan) Scan(src []byte, target any) error { + return plan.next.Scan(src, &anyArrayArrayReflect{array: reflect.ValueOf(target).Elem()}) +} + // PlanScan prepares a plan to scan a value into target. func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan { oidMemo := m.memoizedScanPlans[oid] @@ -1159,6 +1054,10 @@ func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan { } func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { + if target == nil { + return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} + } + if _, ok := target.(*UndecodedBytes); ok { return scanPlanAnyToUndecodedBytes{} } @@ -1200,6 +1099,18 @@ func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { } } + // This needs to happen before trying m.TryWrapScanPlanFuncs. Otherwise, a sql.Scanner would not get called if it was + // defined on a type that could be unwrapped such as `type myString string`. + // + // https://github.com/jackc/pgtype/issues/197 + if _, ok := target.(sql.Scanner); ok { + if dt == nil { + return &scanPlanSQLScanner{formatCode: formatCode} + } else { + return &scanPlanCodecSQLScanner{c: dt.Codec, m: m, oid: oid, formatCode: formatCode} + } + } + for _, f := range m.TryWrapScanPlanFuncs { if wrapperPlan, nextDst, ok := f(target); ok { if nextPlan := m.planScan(oid, formatCode, nextDst); nextPlan != nil { @@ -1215,14 +1126,6 @@ func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { if _, ok := target.(*any); ok { return &pointerEmptyInterfaceScanPlan{codec: dt.Codec, m: m, oid: oid, formatCode: formatCode} } - - if _, ok := target.(sql.Scanner); ok { - return &scanPlanCodecSQLScanner{c: dt.Codec, m: m, oid: oid, formatCode: formatCode} - } - } - - if _, ok := target.(sql.Scanner); ok { - return &scanPlanSQLScanner{formatCode: formatCode} } return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} @@ -1237,25 +1140,6 @@ func (m *Map) Scan(oid uint32, formatCode int16, src []byte, dst any) error { return plan.Scan(src, dst) } -func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest any) error { - switch dest := dest.(type) { - case *string: - if formatCode == BinaryFormatCode { - return fmt.Errorf("unknown oid %d in binary format cannot be scanned into %T", oid, dest) - } - *dest = string(buf) - return nil - case *[]byte: - *dest = buf - return nil - default: - if nextDst, retry := GetAssignToDstType(dest); retry { - return scanUnknownType(oid, formatCode, buf, nextDst) - } - return fmt.Errorf("unknown oid %d cannot be scanned into %T", oid, dest) - } -} - var ErrScanTargetTypeChanged = errors.New("scan target type changed") func codecScan(codec Codec, m *Map, oid uint32, format int16, src []byte, dst any) error { @@ -1289,6 +1173,24 @@ func codecDecodeToTextFormat(codec Codec, m *Map, oid uint32, format int16, src // PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be // found then nil is returned. func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { + oidMemo := m.memoizedEncodePlans[oid] + if oidMemo == nil { + oidMemo = make(map[reflect.Type][2]EncodePlan) + m.memoizedEncodePlans[oid] = oidMemo + } + targetReflectType := reflect.TypeOf(value) + typeMemo := oidMemo[targetReflectType] + plan := typeMemo[format] + if plan == nil { + plan = m.planEncode(oid, format, value) + typeMemo[format] = plan + oidMemo[targetReflectType] = typeMemo + } + + return plan +} + +func (m *Map) planEncode(oid uint32, format int16, value any) EncodePlan { if format == TextFormatCode { switch value.(type) { case string: @@ -1299,16 +1201,16 @@ func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { } var dt *Type - - if oid == 0 { + if dataType, ok := m.TypeForOID(oid); ok { + dt = dataType + } else { + // If no type for the OID was found, then either it is unknowable (e.g. the simple protocol) or it is an + // unregistered type. In either case try to find the type and OID that matches the value (e.g. a []byte would be + // registered to PostgreSQL bytea). if dataType, ok := m.TypeForValue(value); ok { dt = dataType oid = dt.OID // Preserve assumed OID in case we are recursively called below. } - } else { - if dataType, ok := m.TypeForOID(oid); ok { - dt = dataType - } } if dt != nil { @@ -1453,6 +1355,7 @@ var kindToTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{ reflect.Float32: reflect.TypeOf(float32(0)), reflect.Float64: reflect.TypeOf(float64(0)), reflect.String: reflect.TypeOf(""), + reflect.Bool: reflect.TypeOf(false), } type underlyingTypeEncodePlan struct { @@ -1884,11 +1787,7 @@ type wrapSliceEncodePlan[T any] struct { func (plan *wrapSliceEncodePlan[T]) SetNext(next EncodePlan) { plan.next = next } func (plan *wrapSliceEncodePlan[T]) Encode(value any, buf []byte) (newBuf []byte, err error) { - w := anySliceArrayReflect{ - slice: reflect.ValueOf(value), - } - - return plan.next.Encode(w, buf) + return plan.next.Encode((FlatArray[T])(value.([]T)), buf) } type wrapSliceEncodeReflectPlan struct { @@ -1941,6 +1840,35 @@ func (plan *wrapMultiDimSliceEncodePlan) Encode(value any, buf []byte) (newBuf [ return plan.next.Encode(&w, buf) } +func TryWrapArrayEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { + if _, ok := value.(driver.Valuer); ok { + return nil, nil, false + } + + if valueType := reflect.TypeOf(value); valueType != nil && valueType.Kind() == reflect.Array { + w := anyArrayArrayReflect{ + array: reflect.ValueOf(value), + } + return &wrapArrayEncodeReflectPlan{}, w, true + } + + return nil, nil, false +} + +type wrapArrayEncodeReflectPlan struct { + next EncodePlan +} + +func (plan *wrapArrayEncodeReflectPlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapArrayEncodeReflectPlan) Encode(value any, buf []byte) (newBuf []byte, err error) { + w := anyArrayArrayReflect{ + array: reflect.ValueOf(value), + } + + return plan.next.Encode(w, buf) +} + func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error) error { var format string switch formatCode { @@ -1953,13 +1881,13 @@ func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error) } var dataTypeName string - if t, ok := m.oidToType[oid]; ok { + if t, ok := m.TypeForOID(oid); ok { dataTypeName = t.Name } else { dataTypeName = "unknown type" } - return fmt.Errorf("unable to encode %#v into %s format for %s (OID %d): %s", value, format, dataTypeName, oid, err) + return fmt.Errorf("unable to encode %#v into %s format for %s (OID %d): %w", value, format, dataTypeName, oid, err) } // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go new file mode 100644 index 0000000..58f4b92 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go @@ -0,0 +1,223 @@ +package pgtype + +import ( + "net" + "net/netip" + "reflect" + "sync" + "time" +) + +var ( + // defaultMap contains default mappings between PostgreSQL server types and Go type handling logic. + defaultMap *Map + defaultMapInitOnce = sync.Once{} +) + +func initDefaultMap() { + defaultMap = &Map{ + oidToType: make(map[uint32]*Type), + nameToType: make(map[string]*Type), + reflectTypeToName: make(map[reflect.Type]string), + oidToFormatCode: make(map[uint32]int16), + + memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan), + memoizedEncodePlans: make(map[uint32]map[reflect.Type][2]EncodePlan), + + TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ + TryWrapDerefPointerEncodePlan, + TryWrapBuiltinTypeEncodePlan, + TryWrapFindUnderlyingTypeEncodePlan, + TryWrapStructEncodePlan, + TryWrapSliceEncodePlan, + TryWrapMultiDimSliceEncodePlan, + TryWrapArrayEncodePlan, + }, + + TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{ + TryPointerPointerScanPlan, + TryWrapBuiltinTypeScanPlan, + TryFindUnderlyingTypeScanPlan, + TryWrapStructScanPlan, + TryWrapPtrSliceScanPlan, + TryWrapPtrMultiDimSliceScanPlan, + TryWrapPtrArrayScanPlan, + }, + } + + // Base types + defaultMap.RegisterType(&Type{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) + defaultMap.RegisterType(&Type{Name: "bit", OID: BitOID, Codec: BitsCodec{}}) + defaultMap.RegisterType(&Type{Name: "bool", OID: BoolOID, Codec: BoolCodec{}}) + defaultMap.RegisterType(&Type{Name: "box", OID: BoxOID, Codec: BoxCodec{}}) + defaultMap.RegisterType(&Type{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}}) + defaultMap.RegisterType(&Type{Name: "char", OID: QCharOID, Codec: QCharCodec{}}) + defaultMap.RegisterType(&Type{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}}) + defaultMap.RegisterType(&Type{Name: "cidr", OID: CIDROID, Codec: InetCodec{}}) + defaultMap.RegisterType(&Type{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) + defaultMap.RegisterType(&Type{Name: "date", OID: DateOID, Codec: DateCodec{}}) + defaultMap.RegisterType(&Type{Name: "float4", OID: Float4OID, Codec: Float4Codec{}}) + defaultMap.RegisterType(&Type{Name: "float8", OID: Float8OID, Codec: Float8Codec{}}) + defaultMap.RegisterType(&Type{Name: "inet", OID: InetOID, Codec: InetCodec{}}) + defaultMap.RegisterType(&Type{Name: "int2", OID: Int2OID, Codec: Int2Codec{}}) + defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) + defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) + defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) + defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) + defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) + defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) + defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) + defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) + defaultMap.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) + defaultMap.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) + defaultMap.RegisterType(&Type{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}}) + defaultMap.RegisterType(&Type{Name: "path", OID: PathOID, Codec: PathCodec{}}) + defaultMap.RegisterType(&Type{Name: "point", OID: PointOID, Codec: PointCodec{}}) + defaultMap.RegisterType(&Type{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}}) + defaultMap.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}}) + defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) + defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) + defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) + defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) + defaultMap.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) + defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) + defaultMap.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) + defaultMap.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) + + // Range types + defaultMap.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[DateOID]}}) + defaultMap.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[Int4OID]}}) + defaultMap.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[Int8OID]}}) + defaultMap.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[NumericOID]}}) + defaultMap.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[TimestampOID]}}) + defaultMap.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[TimestamptzOID]}}) + + // Multirange types + defaultMap.RegisterType(&Type{Name: "datemultirange", OID: DatemultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[DaterangeOID]}}) + defaultMap.RegisterType(&Type{Name: "int4multirange", OID: Int4multirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[Int4rangeOID]}}) + defaultMap.RegisterType(&Type{Name: "int8multirange", OID: Int8multirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[Int8rangeOID]}}) + defaultMap.RegisterType(&Type{Name: "nummultirange", OID: NummultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[NumrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "tsmultirange", OID: TsmultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[TsrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "tstzmultirange", OID: TstzmultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[TstzrangeOID]}}) + + // Array types + defaultMap.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[ACLItemOID]}}) + defaultMap.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BitOID]}}) + defaultMap.RegisterType(&Type{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BoolOID]}}) + defaultMap.RegisterType(&Type{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BoxOID]}}) + defaultMap.RegisterType(&Type{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BPCharOID]}}) + defaultMap.RegisterType(&Type{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[ByteaOID]}}) + defaultMap.RegisterType(&Type{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[QCharOID]}}) + defaultMap.RegisterType(&Type{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CIDROID]}}) + defaultMap.RegisterType(&Type{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CircleOID]}}) + defaultMap.RegisterType(&Type{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[DateOID]}}) + defaultMap.RegisterType(&Type{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[DaterangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Float4OID]}}) + defaultMap.RegisterType(&Type{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Float8OID]}}) + defaultMap.RegisterType(&Type{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[InetOID]}}) + defaultMap.RegisterType(&Type{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int2OID]}}) + defaultMap.RegisterType(&Type{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int4OID]}}) + defaultMap.RegisterType(&Type{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int4rangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int8OID]}}) + defaultMap.RegisterType(&Type{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int8rangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[IntervalOID]}}) + defaultMap.RegisterType(&Type{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONOID]}}) + defaultMap.RegisterType(&Type{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONBOID]}}) + defaultMap.RegisterType(&Type{Name: "_jsonpath", OID: JSONPathArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONPathOID]}}) + defaultMap.RegisterType(&Type{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[LineOID]}}) + defaultMap.RegisterType(&Type{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[LsegOID]}}) + defaultMap.RegisterType(&Type{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[MacaddrOID]}}) + defaultMap.RegisterType(&Type{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NameOID]}}) + defaultMap.RegisterType(&Type{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NumericOID]}}) + defaultMap.RegisterType(&Type{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NumrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[OIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PathOID]}}) + defaultMap.RegisterType(&Type{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PointOID]}}) + defaultMap.RegisterType(&Type{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PolygonOID]}}) + defaultMap.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[RecordOID]}}) + defaultMap.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TextOID]}}) + defaultMap.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimeOID]}}) + defaultMap.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestampOID]}}) + defaultMap.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestamptzOID]}}) + defaultMap.RegisterType(&Type{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TsrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TstzrangeOID]}}) + defaultMap.RegisterType(&Type{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[UUIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarbitOID]}}) + defaultMap.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarcharOID]}}) + defaultMap.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XIDOID]}}) + + // Integer types that directly map to a PostgreSQL type + registerDefaultPgTypeVariants[int16](defaultMap, "int2") + registerDefaultPgTypeVariants[int32](defaultMap, "int4") + registerDefaultPgTypeVariants[int64](defaultMap, "int8") + + // Integer types that do not have a direct match to a PostgreSQL type + registerDefaultPgTypeVariants[int8](defaultMap, "int8") + registerDefaultPgTypeVariants[int](defaultMap, "int8") + registerDefaultPgTypeVariants[uint8](defaultMap, "int8") + registerDefaultPgTypeVariants[uint16](defaultMap, "int8") + registerDefaultPgTypeVariants[uint32](defaultMap, "int8") + registerDefaultPgTypeVariants[uint64](defaultMap, "numeric") + registerDefaultPgTypeVariants[uint](defaultMap, "numeric") + + registerDefaultPgTypeVariants[float32](defaultMap, "float4") + registerDefaultPgTypeVariants[float64](defaultMap, "float8") + + registerDefaultPgTypeVariants[bool](defaultMap, "bool") + registerDefaultPgTypeVariants[time.Time](defaultMap, "timestamptz") + registerDefaultPgTypeVariants[time.Duration](defaultMap, "interval") + registerDefaultPgTypeVariants[string](defaultMap, "text") + registerDefaultPgTypeVariants[[]byte](defaultMap, "bytea") + + registerDefaultPgTypeVariants[net.IP](defaultMap, "inet") + registerDefaultPgTypeVariants[net.IPNet](defaultMap, "cidr") + registerDefaultPgTypeVariants[netip.Addr](defaultMap, "inet") + registerDefaultPgTypeVariants[netip.Prefix](defaultMap, "cidr") + + // pgtype provided structs + registerDefaultPgTypeVariants[Bits](defaultMap, "varbit") + registerDefaultPgTypeVariants[Bool](defaultMap, "bool") + registerDefaultPgTypeVariants[Box](defaultMap, "box") + registerDefaultPgTypeVariants[Circle](defaultMap, "circle") + registerDefaultPgTypeVariants[Date](defaultMap, "date") + registerDefaultPgTypeVariants[Range[Date]](defaultMap, "daterange") + registerDefaultPgTypeVariants[Multirange[Range[Date]]](defaultMap, "datemultirange") + registerDefaultPgTypeVariants[Float4](defaultMap, "float4") + registerDefaultPgTypeVariants[Float8](defaultMap, "float8") + registerDefaultPgTypeVariants[Range[Float8]](defaultMap, "numrange") // There is no PostgreSQL builtin float8range so map it to numrange. + registerDefaultPgTypeVariants[Multirange[Range[Float8]]](defaultMap, "nummultirange") // There is no PostgreSQL builtin float8multirange so map it to nummultirange. + registerDefaultPgTypeVariants[Int2](defaultMap, "int2") + registerDefaultPgTypeVariants[Int4](defaultMap, "int4") + registerDefaultPgTypeVariants[Range[Int4]](defaultMap, "int4range") + registerDefaultPgTypeVariants[Multirange[Range[Int4]]](defaultMap, "int4multirange") + registerDefaultPgTypeVariants[Int8](defaultMap, "int8") + registerDefaultPgTypeVariants[Range[Int8]](defaultMap, "int8range") + registerDefaultPgTypeVariants[Multirange[Range[Int8]]](defaultMap, "int8multirange") + registerDefaultPgTypeVariants[Interval](defaultMap, "interval") + registerDefaultPgTypeVariants[Line](defaultMap, "line") + registerDefaultPgTypeVariants[Lseg](defaultMap, "lseg") + registerDefaultPgTypeVariants[Numeric](defaultMap, "numeric") + registerDefaultPgTypeVariants[Range[Numeric]](defaultMap, "numrange") + registerDefaultPgTypeVariants[Multirange[Range[Numeric]]](defaultMap, "nummultirange") + registerDefaultPgTypeVariants[Path](defaultMap, "path") + registerDefaultPgTypeVariants[Point](defaultMap, "point") + registerDefaultPgTypeVariants[Polygon](defaultMap, "polygon") + registerDefaultPgTypeVariants[TID](defaultMap, "tid") + registerDefaultPgTypeVariants[Text](defaultMap, "text") + registerDefaultPgTypeVariants[Time](defaultMap, "time") + registerDefaultPgTypeVariants[Timestamp](defaultMap, "timestamp") + registerDefaultPgTypeVariants[Timestamptz](defaultMap, "timestamptz") + registerDefaultPgTypeVariants[Range[Timestamp]](defaultMap, "tsrange") + registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](defaultMap, "tsmultirange") + registerDefaultPgTypeVariants[Range[Timestamptz]](defaultMap, "tstzrange") + registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](defaultMap, "tstzmultirange") + registerDefaultPgTypeVariants[UUID](defaultMap, "uuid") + + defaultMap.buildReflectTypeToType() +} diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/point.go b/vendor/github.com/jackc/pgx/v5/pgtype/point.go index cfa5a9f..b5a4320 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/point.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/point.go @@ -40,7 +40,7 @@ func (p Point) PointValue() (Point, error) { } func parsePoint(src []byte) (*Point, error) { - if src == nil || bytes.Compare(src, []byte("null")) == 0 { + if src == nil || bytes.Equal(src, []byte("null")) { return &Point{}, nil } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/tid.go b/vendor/github.com/jackc/pgx/v5/pgtype/tid.go index cb4a9ec..5839e87 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/tid.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/tid.go @@ -22,7 +22,7 @@ type TIDValuer interface { // // When one does // -// select ctid, * from some_table; +// select ctid, * from some_table; // // it is the data type of the ctid hidden system column. // diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go b/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go index 9f3de2c..35d7395 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "fmt" "strings" "time" @@ -66,6 +67,55 @@ func (ts Timestamp) Value() (driver.Value, error) { return ts.Time, nil } +func (ts Timestamp) MarshalJSON() ([]byte, error) { + if !ts.Valid { + return []byte("null"), nil + } + + var s string + + switch ts.InfinityModifier { + case Finite: + s = ts.Time.Format(time.RFC3339Nano) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return json.Marshal(s) +} + +func (ts *Timestamp) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *ts = Timestamp{} + return nil + } + + switch *s { + case "infinity": + *ts = Timestamp{Valid: true, InfinityModifier: Infinity} + case "-infinity": + *ts = Timestamp{Valid: true, InfinityModifier: -Infinity} + default: + // PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz + tim, err := time.Parse(time.RFC3339Nano, *s) + if err != nil { + return err + } + + *ts = Timestamp{Time: tim, Valid: true} + } + + return nil +} + type TimestampCodec struct{} func (TimestampCodec) FormatSupported(format int16) bool { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go b/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go index 96a4c32..b59d6e7 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go @@ -97,7 +97,7 @@ func (src UUID) MarshalJSON() ([]byte, error) { } func (dst *UUID) UnmarshalJSON(src []byte) error { - if bytes.Compare(src, []byte("null")) == 0 { + if bytes.Equal(src, []byte("null")) { *dst = UUID{} return nil } diff --git a/vendor/github.com/jackc/pgx/v5/rows.go b/vendor/github.com/jackc/pgx/v5/rows.go index ffe739b..1b1c8ac 100644 --- a/vendor/github.com/jackc/pgx/v5/rows.go +++ b/vendor/github.com/jackc/pgx/v5/rows.go @@ -28,12 +28,16 @@ type Rows interface { // to call Close after rows is already closed. Close() - // Err returns any error that occurred while reading. + // Err returns any error that occurred while reading. Err must only be called after the Rows is closed (either by + // calling Close or by Next returning false). If it is called early it may return nil even if there was an error + // executing the query. Err() error // CommandTag returns the command tag from this query. It is only available after Rows is closed. CommandTag() pgconn.CommandTag + // FieldDescriptions returns the field descriptions of the columns. It may return nil. In particular this can occur + // when there was an error executing the query. FieldDescriptions() []pgconn.FieldDescription // Next prepares the next row for reading. It returns true if there is another @@ -227,7 +231,11 @@ func (rows *baseRows) Scan(dest ...any) error { if len(dest) == 1 { if rc, ok := dest[0].(RowScanner); ok { - return rc.ScanRow(rows) + err := rc.ScanRow(rows) + if err != nil { + rows.fatal(err) + } + return err } } @@ -298,7 +306,7 @@ func (rows *baseRows) Values() ([]any, error) { copy(newBuf, buf) values = append(values, newBuf) default: - rows.fatal(errors.New("Unknown format code")) + rows.fatal(errors.New("unknown format code")) } } @@ -488,7 +496,8 @@ func (rs *mapRowScanner) ScanRow(rows Rows) error { } // RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row -// has fields. The row and T fields will by matched by position. +// has fields. The row and T fields will by matched by position. If the "db" struct tag is "-" then the field will be +// ignored. func RowToStructByPos[T any](row CollectableRow) (T, error) { var value T err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) @@ -496,7 +505,8 @@ func RowToStructByPos[T any](row CollectableRow) (T, error) { } // RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a -// public fields as row has fields. The row and T fields will by matched by position. +// public fields as row has fields. The row and T fields will by matched by position. If the "db" struct tag is "-" then +// the field will be ignored. func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) { var value T err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) @@ -533,13 +543,16 @@ func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Val for i := 0; i < dstElemType.NumField(); i++ { sf := dstElemType.Field(i) - if sf.PkgPath == "" { - // Handle anonymous struct embedding, but do not try to handle embedded pointers. - if sf.Anonymous && sf.Type.Kind() == reflect.Struct { - scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets) - } else { - scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) + // Handle anonymous struct embedding, but do not try to handle embedded pointers. + if sf.Anonymous && sf.Type.Kind() == reflect.Struct { + scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets) + } else if sf.PkgPath == "" { + dbTag, _ := sf.Tag.Lookup(structTagKey) + if dbTag == "-" { + // Field is ignored, skip it. + continue } + scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) } } @@ -565,8 +578,28 @@ func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) { return &value, err } +// RowToStructByNameLax returns a T scanned from row. T must be a struct. T must have greater than or equal number of named public +// fields as row has fields. The row and T fields will by matched by name. The match is case-insensitive. The database +// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. +func RowToStructByNameLax[T any](row CollectableRow) (T, error) { + var value T + err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) + return value, err +} + +// RowToAddrOfStructByNameLax returns the address of a T scanned from row. T must be a struct. T must have greater than or +// equal number of named public fields as row has fields. The row and T fields will by matched by name. The match is +// case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" +// then the field will be ignored. +func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) { + var value T + err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) + return &value, err +} + type namedStructRowScanner struct { ptrToStruct any + lax bool } func (rs *namedStructRowScanner) ScanRow(rows Rows) error { @@ -578,7 +611,6 @@ func (rs *namedStructRowScanner) ScanRow(rows Rows) error { dstElemValue := dstValue.Elem() scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions()) - if err != nil { return err } @@ -638,7 +670,13 @@ func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, s colName = sf.Name } fpos := fieldPosByName(fldDescs, colName) - if fpos == -1 || fpos >= len(scanTargets) { + if fpos == -1 { + if rs.lax { + continue + } + return nil, fmt.Errorf("cannot find field %s in returned row", colName) + } + if fpos >= len(scanTargets) && !rs.lax { return nil, fmt.Errorf("cannot find field %s in returned row", colName) } scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface() diff --git a/vendor/github.com/jackc/pgx/v5/stdlib/sql.go b/vendor/github.com/jackc/pgx/v5/stdlib/sql.go index fc0b023..97ecc9b 100644 --- a/vendor/github.com/jackc/pgx/v5/stdlib/sql.go +++ b/vendor/github.com/jackc/pgx/v5/stdlib/sql.go @@ -2,58 +2,58 @@ // // A database/sql connection can be established through sql.Open. // -// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable") -// if err != nil { -// return err -// } +// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable") +// if err != nil { +// return err +// } // // Or from a DSN string. // -// db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable") -// if err != nil { -// return err -// } +// db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable") +// if err != nil { +// return err +// } // // Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the // pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used // with sql.Open. // -// connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL")) -// connConfig.Logger = myLogger -// connStr := stdlib.RegisterConnConfig(connConfig) -// db, _ := sql.Open("pgx", connStr) +// connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL")) +// connConfig.Logger = myLogger +// connStr := stdlib.RegisterConnConfig(connConfig) +// db, _ := sql.Open("pgx", connStr) // // pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2. It does not support named parameters. // -// db.QueryRow("select * from users where id=$1", userID) +// db.QueryRow("select * from users where id=$1", userID) // // (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard database/sql.DB connection pool. This allows // operations that use pgx specific functionality. // -// // Given db is a *sql.DB -// conn, err := db.Conn(context.Background()) -// if err != nil { -// // handle error from acquiring connection from DB pool -// } +// // Given db is a *sql.DB +// conn, err := db.Conn(context.Background()) +// if err != nil { +// // handle error from acquiring connection from DB pool +// } // -// err = conn.Raw(func(driverConn any) error { -// conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn -// // Do pgx specific stuff with conn -// conn.CopyFrom(...) -// return nil -// }) -// if err != nil { -// // handle error that occurred while using *pgx.Conn -// } +// err = conn.Raw(func(driverConn any) error { +// conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn +// // Do pgx specific stuff with conn +// conn.CopyFrom(...) +// return nil +// }) +// if err != nil { +// // handle error that occurred while using *pgx.Conn +// } // -// PostgreSQL Specific Data Types +// # PostgreSQL Specific Data Types // // The pgtype package provides support for PostgreSQL specific types. *pgtype.Map.SQLScanner is an adapter that makes // these types usable as a sql.Scanner. // -// m := pgtype.NewMap() -// var a []int64 -// err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a)) +// m := pgtype.NewMap() +// var a []int64 +// err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a)) package stdlib import ( @@ -85,7 +85,13 @@ func init() { pgxDriver = &Driver{ configs: make(map[string]*pgx.ConnConfig), } - sql.Register("pgx", pgxDriver) + + // if pgx driver was already registered by different pgx major version then we + // skip registration under the default name. + if !contains(sql.Drivers(), "pgx") { + sql.Register("pgx", pgxDriver) + } + sql.Register("pgx/v5", pgxDriver) databaseSQLResultFormats = pgx.QueryResultFormatsByOID{ pgtype.BoolOID: 1, @@ -104,6 +110,17 @@ func init() { } } +// TODO replace by slices.Contains when experimental package will be merged to stdlib +// https://pkg.go.dev/golang.org/x/exp/slices#Contains +func contains(list []string, y string) bool { + for _, x := range list { + if x == y { + return true + } + } + return false +} + // OptionOpenDB options for configuring the driver when opening a new db pool. type OptionOpenDB func(*connector) @@ -140,7 +157,7 @@ func RandomizeHostOrderFunc(ctx context.Context, connConfig *pgx.ConnConfig) err return nil } - newFallbacks := append([]*pgconn.FallbackConfig{&pgconn.FallbackConfig{ + newFallbacks := append([]*pgconn.FallbackConfig{{ Host: connConfig.Host, Port: connConfig.Port, TLSConfig: connConfig.TLSConfig, diff --git a/vendor/github.com/jackc/pgx/v5/tx.go b/vendor/github.com/jackc/pgx/v5/tx.go index e57142a..8feeb51 100644 --- a/vendor/github.com/jackc/pgx/v5/tx.go +++ b/vendor/github.com/jackc/pgx/v5/tx.go @@ -44,6 +44,10 @@ type TxOptions struct { IsoLevel TxIsoLevel AccessMode TxAccessMode DeferrableMode TxDeferrableMode + + // BeginQuery is the SQL query that will be executed to begin the transaction. This allows using non-standard syntax + // such as BEGIN PRIORITY HIGH with CockroachDB. If set this will override the other settings. + BeginQuery string } var emptyTxOptions TxOptions @@ -53,6 +57,10 @@ func (txOptions TxOptions) beginSQL() string { return "begin" } + if txOptions.BeginQuery != "" { + return txOptions.BeginQuery + } + var buf strings.Builder buf.Grow(64) // 64 - maximum length of string with available options buf.WriteString("begin") @@ -144,7 +152,6 @@ type Tx interface { // called on the dbTx. type dbTx struct { conn *Conn - err error savepointNum int64 closed bool } diff --git a/vendor/gorm.io/driver/postgres/.gitignore b/vendor/gorm.io/driver/postgres/.gitignore new file mode 100644 index 0000000..485dee6 --- /dev/null +++ b/vendor/gorm.io/driver/postgres/.gitignore @@ -0,0 +1 @@ +.idea diff --git a/vendor/gorm.io/driver/postgres/error_translator.go b/vendor/gorm.io/driver/postgres/error_translator.go new file mode 100644 index 0000000..9c0ef25 --- /dev/null +++ b/vendor/gorm.io/driver/postgres/error_translator.go @@ -0,0 +1,48 @@ +package postgres + +import ( + "encoding/json" + + "gorm.io/gorm" + + "github.com/jackc/pgx/v5/pgconn" +) + +var errCodes = map[string]error{ + "23505": gorm.ErrDuplicatedKey, + "23503": gorm.ErrForeignKeyViolated, + "42703": gorm.ErrInvalidField, +} + +type ErrMessage struct { + Code string + Severity string + Message string +} + +// Translate it will translate the error to native gorm errors. +// Since currently gorm supporting both pgx and pg drivers, only checking for pgx PgError types is not enough for translating errors, so we have additional error json marshal fallback. +func (dialector Dialector) Translate(err error) error { + if pgErr, ok := err.(*pgconn.PgError); ok { + if translatedErr, found := errCodes[pgErr.Code]; found { + return translatedErr + } + return err + } + + parsedErr, marshalErr := json.Marshal(err) + if marshalErr != nil { + return err + } + + var errMsg ErrMessage + unmarshalErr := json.Unmarshal(parsedErr, &errMsg) + if unmarshalErr != nil { + return err + } + + if translatedErr, found := errCodes[errMsg.Code]; found { + return translatedErr + } + return err +} diff --git a/vendor/gorm.io/driver/postgres/migrator.go b/vendor/gorm.io/driver/postgres/migrator.go index 3dec4a7..6174e1c 100644 --- a/vendor/gorm.io/driver/postgres/migrator.go +++ b/vendor/gorm.io/driver/postgres/migrator.go @@ -13,44 +13,61 @@ import ( "gorm.io/gorm/schema" ) +// See https://stackoverflow.com/questions/2204058/list-columns-with-indexes-in-postgresql +// Here are some changes: +// - use `LEFT JOIN` instead of `CROSS JOIN` +// - exclude indexes used to support constraints (they are auto-generated) const indexSql = ` -select - t.relname as table_name, - i.relname as index_name, - a.attname as column_name, - ix.indisunique as non_unique, - ix.indisprimary as primary -from - pg_class t, - pg_class i, - pg_index ix, - pg_attribute a -where - t.oid = ix.indrelid - and i.oid = ix.indexrelid - and a.attrelid = t.oid - and a.attnum = ANY(ix.indkey) - and t.relkind = 'r' - and t.relname = ? +SELECT + ct.relname AS table_name, + ci.relname AS index_name, + i.indisunique AS non_unique, + i.indisprimary AS primary, + a.attname AS column_name +FROM + pg_index i + LEFT JOIN pg_class ct ON ct.oid = i.indrelid + LEFT JOIN pg_class ci ON ci.oid = i.indexrelid + LEFT JOIN pg_attribute a ON a.attrelid = ct.oid + LEFT JOIN pg_constraint con ON con.conindid = i.indexrelid +WHERE + a.attnum = ANY(i.indkey) + AND con.oid IS NULL + AND ct.relkind = 'r' + AND ct.relname = ? ` var typeAliasMap = map[string][]string{ - "int2": {"smallint"}, - "int4": {"integer"}, - "int8": {"bigint"}, - "smallint": {"int2"}, - "integer": {"int4"}, - "bigint": {"int8"}, - "decimal": {"numeric"}, - "numeric": {"decimal"}, + "int2": {"smallint"}, + "int4": {"integer"}, + "int8": {"bigint"}, + "smallint": {"int2"}, + "integer": {"int4"}, + "bigint": {"int8"}, + "decimal": {"numeric"}, + "numeric": {"decimal"}, + "timestamptz": {"timestamp with time zone"}, + "timestamp with time zone": {"timestamptz"}, + "bool": {"boolean"}, + "boolean": {"bool"}, } type Migrator struct { migrator.Migrator } +// select querys ignore dryrun +func (m Migrator) queryRaw(sql string, values ...interface{}) (tx *gorm.DB) { + queryTx := m.DB + if m.DB.DryRun { + queryTx = m.DB.Session(&gorm.Session{}) + queryTx.DryRun = false + } + return queryTx.Raw(sql, values...) +} + func (m Migrator) CurrentDatabase() (name string) { - m.DB.Raw("SELECT CURRENT_DATABASE()").Scan(&name) + m.queryRaw("SELECT CURRENT_DATABASE()").Scan(&name) return } @@ -76,11 +93,13 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem func (m Migrator) HasIndex(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } } currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) - return m.DB.Raw( + return m.queryRaw( "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?", curTable, name, currentSchema, ).Scan(&count).Error }) @@ -90,33 +109,35 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - opts := m.BuildIndexOptions(idx.Fields, stmt) - values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + opts := m.BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} - createIndexSQL := "CREATE " - if idx.Class != "" { - createIndexSQL += idx.Class + " " + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX " + + if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" { + createIndexSQL += "CONCURRENTLY " + } + + createIndexSQL += "IF NOT EXISTS ? ON ?" + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + "(?)" + } else { + createIndexSQL += " ?" + } + + if idx.Where != "" { + createIndexSQL += " WHERE " + idx.Where + } + + return m.DB.Exec(createIndexSQL, values...).Error } - createIndexSQL += "INDEX " - - if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" { - createIndexSQL += "CONCURRENTLY " - } - - createIndexSQL += "IF NOT EXISTS ? ON ?" - - if idx.Type != "" { - createIndexSQL += " USING " + idx.Type + "(?)" - } else { - createIndexSQL += " ?" - } - - if idx.Where != "" { - createIndexSQL += " WHERE " + idx.Where - } - - return m.DB.Exec(createIndexSQL, values...).Error } return fmt.Errorf("failed to create index with name %v", name) @@ -134,8 +155,10 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error func (m Migrator) DropIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } } return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error @@ -144,7 +167,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error { func (m Migrator) GetTables() (tableList []string, err error) { currentSchema, _ := m.CurrentSchema(m.DB.Statement, "") - return tableList, m.DB.Raw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error + return tableList, m.queryRaw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error } func (m Migrator) CreateTable(values ...interface{}) (err error) { @@ -153,13 +176,16 @@ func (m Migrator) CreateTable(values ...interface{}) (err error) { } for _, value := range m.ReorderModels(values, false) { if err = m.RunWithValue(value, func(stmt *gorm.Statement) error { - for _, field := range stmt.Schema.FieldsByDBName { - if field.Comment != "" { - if err := m.DB.Exec( - "COMMENT ON COLUMN ?.? IS ?", - m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), - ).Error; err != nil { - return err + if stmt.Schema != nil { + for _, fieldName := range stmt.Schema.DBNames { + field := stmt.Schema.FieldsByDBName[fieldName] + if field.Comment != "" { + if err := m.DB.Exec( + "COMMENT ON COLUMN ?.? IS ?", + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), + ).Error; err != nil { + return err + } } } } @@ -175,7 +201,7 @@ func (m Migrator) HasTable(value interface{}) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) - return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error + return m.queryRaw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error }) return count > 0 } @@ -200,13 +226,15 @@ func (m Migrator) AddColumn(value interface{}, field string) error { m.resetPreparedStmts() return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(field); field != nil { - if field.Comment != "" { - if err := m.DB.Exec( - "COMMENT ON COLUMN ?.? IS ?", - m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), - ).Error; err != nil { - return err + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(field); field != nil { + if field.Comment != "" { + if err := m.DB.Exec( + "COMMENT ON COLUMN ?.? IS ?", + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), + ).Error; err != nil { + return err + } } } } @@ -225,7 +253,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { } currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) - return m.DB.Raw( + return m.queryRaw( "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentSchema, curTable, name, ).Scan(&count).Error @@ -250,7 +278,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy checkSQL += "WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) " checkSQL += "AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = " checkSQL += "(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))" - m.DB.Raw(checkSQL, values...).Scan(&description) + m.queryRaw(checkSQL, values...).Scan(&description) comment := strings.Trim(field.Comment, "'") comment = strings.Trim(comment, `"`) @@ -269,101 +297,92 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy // AlterColumn alter value's `field` column' type based on schema definition func (m Migrator) AlterColumn(value interface{}, field string) error { err := m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(field); field != nil { - var ( - columnTypes, _ = m.DB.Migrator().ColumnTypes(value) - fieldColumnType *migrator.ColumnType - ) - for _, columnType := range columnTypes { - if columnType.Name() == field.DBName { - fieldColumnType, _ = columnType.(*migrator.ColumnType) - } - } - - fileType := clause.Expr{SQL: m.DataTypeOf(field)} - // check for typeName and SQL name - isSameType := true - if fieldColumnType.DatabaseTypeName() != fileType.SQL { - isSameType = false - // if different, also check for aliases - aliases := m.GetTypeAliases(fieldColumnType.DatabaseTypeName()) - for _, alias := range aliases { - if strings.HasPrefix(fileType.SQL, alias) { - isSameType = true - break + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(field); field != nil { + var ( + columnTypes, _ = m.DB.Migrator().ColumnTypes(value) + fieldColumnType *migrator.ColumnType + ) + for _, columnType := range columnTypes { + if columnType.Name() == field.DBName { + fieldColumnType, _ = columnType.(*migrator.ColumnType) } } - } - // not same, migrate - if !isSameType { - filedColumnAutoIncrement, _ := fieldColumnType.AutoIncrement() - if field.AutoIncrement && filedColumnAutoIncrement { // update - serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL) - if t, _ := fieldColumnType.ColumnType(); t != serialDatabaseType { - if err := m.UpdateSequence(m.DB, stmt, field, serialDatabaseType); err != nil { - return err + fileType := clause.Expr{SQL: m.DataTypeOf(field)} + // check for typeName and SQL name + isSameType := true + if fieldColumnType.DatabaseTypeName() != fileType.SQL { + isSameType = false + // if different, also check for aliases + aliases := m.GetTypeAliases(fieldColumnType.DatabaseTypeName()) + for _, alias := range aliases { + if strings.HasPrefix(fileType.SQL, alias) { + isSameType = true + break } } - } else if field.AutoIncrement && !filedColumnAutoIncrement { // create - serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL) - if err := m.CreateSequence(m.DB, stmt, field, serialDatabaseType); err != nil { - return err - } - } else if !field.AutoIncrement && filedColumnAutoIncrement { // delete - if err := m.DeleteSequence(m.DB, stmt, field, fileType); err != nil { - return err - } - } else { - if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::?", - m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, clause.Column{Name: field.DBName}, fileType).Error; err != nil { - return err - } } - } - if null, _ := fieldColumnType.Nullable(); null == field.NotNull { - if field.NotNull { - if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { - return err - } - } else { - if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { - return err - } - } - } - - if uniq, _ := fieldColumnType.Unique(); !uniq && field.Unique { - idxName := clause.Column{Name: m.DB.Config.NamingStrategy.IndexName(stmt.Table, field.DBName)} - // Not a unique constraint but a unique index - if !m.HasIndex(stmt.Table, idxName.Name) { - if err := m.DB.Exec("ALTER TABLE ? ADD CONSTRAINT ? UNIQUE(?)", m.CurrentTable(stmt), idxName, clause.Column{Name: field.DBName}).Error; err != nil { - return err - } - } - } - - if v, ok := fieldColumnType.DefaultValue(); (field.DefaultValueInterface == nil && ok) || v != field.DefaultValue { - if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { - if field.DefaultValueInterface != nil { - defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} - m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) - if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)}).Error; err != nil { + // not same, migrate + if !isSameType { + filedColumnAutoIncrement, _ := fieldColumnType.AutoIncrement() + if field.AutoIncrement && filedColumnAutoIncrement { // update + serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL) + if t, _ := fieldColumnType.ColumnType(); t != serialDatabaseType { + if err := m.UpdateSequence(m.DB, stmt, field, serialDatabaseType); err != nil { + return err + } + } + } else if field.AutoIncrement && !filedColumnAutoIncrement { // create + serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL) + if err := m.CreateSequence(m.DB, stmt, field, serialDatabaseType); err != nil { return err } - } else if field.DefaultValue != "(-)" { - if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { + } else if !field.AutoIncrement && filedColumnAutoIncrement { // delete + if err := m.DeleteSequence(m.DB, stmt, field, fileType); err != nil { return err } } else { - if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { + if err := m.modifyColumn(stmt, field, fileType, fieldColumnType); err != nil { return err } } } + + if null, _ := fieldColumnType.Nullable(); null == field.NotNull { + if field.NotNull { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { + return err + } + } else { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { + return err + } + } + } + + if v, ok := fieldColumnType.DefaultValue(); (field.DefaultValueInterface == nil && ok) || v != field.DefaultValue { + if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { + if field.DefaultValueInterface != nil { + defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} + m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)}).Error; err != nil { + return err + } + } else if field.DefaultValue != "(-)" { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { + return err + } + } else { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { + return err + } + } + } + } + return nil } - return nil } return fmt.Errorf("failed to look up field with name: %s", field) }) @@ -375,18 +394,39 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { return nil } +func (m Migrator) modifyColumn(stmt *gorm.Statement, field *schema.Field, targetType clause.Expr, existingColumn *migrator.ColumnType) error { + alterSQL := "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::?" + isUncastableDefaultValue := false + + if targetType.SQL == "boolean" { + switch existingColumn.DatabaseTypeName() { + case "int2", "int8", "numeric": + alterSQL = "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::int::?" + } + isUncastableDefaultValue = true + } + + if dv, _ := existingColumn.DefaultValue(); dv != "" && isUncastableDefaultValue { + if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { + return err + } + } + if err := m.DB.Exec(alterSQL, m.CurrentTable(stmt), clause.Column{Name: field.DBName}, targetType, clause.Column{Name: field.DBName}, targetType).Error; err != nil { + return err + } + return nil +} + func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { - constraint, chk, table := m.GuessConstraintAndTable(stmt, name) - currentSchema, curTable := m.CurrentSchema(stmt, table) + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { - name = constraint.Name - } else if chk != nil { - name = chk.Name + name = constraint.GetName() } + currentSchema, curTable := m.CurrentSchema(stmt, table) - return m.DB.Raw( + return m.queryRaw( "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?", currentSchema, curTable, name, ).Scan(&count).Error @@ -401,7 +441,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, var ( currentDatabase = m.DB.Migrator().CurrentDatabase() currentSchema, table = m.CurrentSchema(stmt, stmt.Table) - columns, err = m.DB.Raw( + columns, err = m.queryRaw( "SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description, c.identity_increment FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?", currentDatabase, currentSchema, table).Rows() ) @@ -441,7 +481,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, } if column.DefaultValueValue.Valid { - column.DefaultValueValue.String = regexp.MustCompile(`'?(.*)\b'?:+[\w\s]+$`).ReplaceAllString(column.DefaultValueValue.String, "$1") + column.DefaultValueValue.String = parseDefaultValueValue(column.DefaultValueValue.String) } if datetimePrecision.Valid { @@ -475,7 +515,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, // check primary, unique field { - columnTypeRows, err := m.DB.Raw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows() + columnTypeRows, err := m.queryRaw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows() if err != nil { return err } @@ -487,7 +527,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, } columnTypeRows.Close() - columnTypeRows, err = m.DB.Raw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows() + columnTypeRows, err = m.queryRaw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows() if err != nil { return err } @@ -514,7 +554,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, // check column type { - dataTypeRows, err := m.DB.Raw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type + dataTypeRows, err := m.queryRaw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.oid AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?) WHERE a.attnum > 0 -- hide internal columns AND NOT a.attisdropped -- hide deleted columns @@ -672,7 +712,7 @@ func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) { err := m.RunWithValue(value, func(stmt *gorm.Statement) error { result := make([]*Index, 0) - scanErr := m.DB.Raw(indexSql, stmt.Table).Scan(&result).Error + scanErr := m.queryRaw(indexSql, stmt.Table).Scan(&result).Error if scanErr != nil { return scanErr } @@ -747,3 +787,8 @@ func (m Migrator) RenameColumn(dst interface{}, oldName, field string) error { m.resetPreparedStmts() return nil } + +func parseDefaultValueValue(defaultValue string) string { + value := regexp.MustCompile(`^(.*?)(?:::.*)?$`).ReplaceAllString(defaultValue, "$1") + return strings.Trim(value, "'") +} diff --git a/vendor/gorm.io/driver/postgres/postgres.go b/vendor/gorm.io/driver/postgres/postgres.go index 5d73d27..e865b0f 100644 --- a/vendor/gorm.io/driver/postgres/postgres.go +++ b/vendor/gorm.io/driver/postgres/postgres.go @@ -24,11 +24,17 @@ type Dialector struct { type Config struct { DriverName string DSN string + WithoutQuotingCheck bool PreferSimpleProtocol bool WithoutReturning bool Conn gorm.ConnPool } +var ( + timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )") + defaultIdentifierLength = 63 //maximum identifier length for postgres +) + func Open(dsn string) gorm.Dialector { return &Dialector{&Config{DSN: dsn}} } @@ -41,17 +47,42 @@ func (dialector Dialector) Name() string { return "postgres" } -var timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )") +func (dialector Dialector) Apply(config *gorm.Config) error { + if config.NamingStrategy == nil { + config.NamingStrategy = schema.NamingStrategy{ + IdentifierMaxLength: defaultIdentifierLength, + } + return nil + } + + switch v := config.NamingStrategy.(type) { + case *schema.NamingStrategy: + if v.IdentifierMaxLength <= 0 { + v.IdentifierMaxLength = defaultIdentifierLength + } + case schema.NamingStrategy: + if v.IdentifierMaxLength <= 0 { + v.IdentifierMaxLength = defaultIdentifierLength + config.NamingStrategy = v + } + } + + return nil +} func (dialector Dialector) Initialize(db *gorm.DB) (err error) { + callbackConfig := &callbacks.Config{ + CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT"}, + UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE"}, + DeleteClauses: []string{"DELETE", "FROM", "WHERE"}, + } // register callbacks if !dialector.WithoutReturning { - callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ - CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, - UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, - DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"}, - }) + callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING") + callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING") + callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING") } + callbacks.RegisterDefaultCallbacks(db, callbackConfig) if dialector.Conn != nil { db.ConnPool = dialector.Conn @@ -90,10 +121,23 @@ func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { writer.WriteByte('$') - writer.WriteString(strconv.Itoa(len(stmt.Vars))) + index := 0 + varLen := len(stmt.Vars) + if varLen > 0 { + switch stmt.Vars[0].(type) { + case pgx.QueryExecMode: + index++ + } + } + writer.WriteString(strconv.Itoa(varLen - index)) } func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { + if dialector.WithoutQuotingCheck { + writer.WriteString(str) + return + } + var ( underQuoted, selfQuoted bool continuousBacktick int8 diff --git a/vendor/modules.txt b/vendor/modules.txt index 89a07ef..70eb6c0 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -83,8 +83,6 @@ github.com/go-playground/universal-translator # github.com/go-playground/validator/v10 v10.25.0 ## explicit; go 1.20 github.com/go-playground/validator/v10 -# github.com/golang-jwt/jwt v3.2.2+incompatible -## explicit # github.com/golang-jwt/jwt/v5 v5.2.1 ## explicit; go 1.18 github.com/golang-jwt/jwt/v5 @@ -109,16 +107,16 @@ github.com/jackc/pgpassfile # github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 ## explicit; go 1.14 github.com/jackc/pgservicefile -# github.com/jackc/pgx/v5 v5.2.0 -## explicit; go 1.18 +# github.com/jackc/pgx/v5 v5.4.3 +## explicit; go 1.19 github.com/jackc/pgx/v5 github.com/jackc/pgx/v5/internal/anynil github.com/jackc/pgx/v5/internal/iobufpool -github.com/jackc/pgx/v5/internal/nbconn github.com/jackc/pgx/v5/internal/pgio github.com/jackc/pgx/v5/internal/sanitize github.com/jackc/pgx/v5/internal/stmtcache github.com/jackc/pgx/v5/pgconn +github.com/jackc/pgx/v5/pgconn/internal/bgreader github.com/jackc/pgx/v5/pgconn/internal/ctxwatch github.com/jackc/pgx/v5/pgproto3 github.com/jackc/pgx/v5/pgtype @@ -284,8 +282,8 @@ gopkg.in/ini.v1 # gopkg.in/yaml.v3 v3.0.1 ## explicit gopkg.in/yaml.v3 -# gorm.io/driver/postgres v1.4.7 -## explicit; go 1.14 +# gorm.io/driver/postgres v1.5.7 +## explicit; go 1.18 gorm.io/driver/postgres # gorm.io/gorm v1.25.12 ## explicit; go 1.18