本站源代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1257 lines
34KB

  1. // Copyright 2015 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "database/sql/driver"
  7. "fmt"
  8. "reflect"
  9. "strings"
  10. "time"
  11. "xorm.io/builder"
  12. "xorm.io/core"
  13. )
// Statement save all the sql info for executing SQL
//
// Statement accumulates the state set by the chainable builder methods
// (Where/And/Or, Cols, Join, Limit, OrderBy, ...), which the gen*SQL methods
// later turn into SQL text and arguments.
type Statement struct {
	// RefTable is the mapped table metadata of the reference bean.
	RefTable *core.Table
	// Engine is the owning engine; used for quoting, dialect access and table mapping.
	Engine *Engine
	// Start is the row offset set by Limit's optional start argument.
	Start int
	// LimitN is the row limit set by Limit/Top.
	LimitN int
	// idParam holds the primary-key value(s) set by ID.
	idParam *core.PK
	// OrderStr is the ORDER BY list built by OrderBy, Asc and Desc.
	OrderStr string
	// JoinStr is the accumulated JOIN clause built by Join.
	JoinStr string
	// joinArgs are the arguments bound to placeholders in JoinStr.
	joinArgs []interface{}
	// GroupByStr is the GROUP BY expression set by GroupBy.
	GroupByStr string
	// HavingStr is the complete "HAVING ..." clause set by Having.
	HavingStr string
	// ColumnStr is the quoted column list built by Cols.
	ColumnStr string
	// selectStr is the raw select expression set by Select; when non-empty it
	// replaces the generated column list.
	selectStr string
	// useAllCols is set by AllCols; updates then write zero values too.
	useAllCols bool
	// OmitStr is the quoted column list from the most recent Omit call.
	OmitStr string
	// AltTableName is the table name set by Table; it overrides tableName.
	AltTableName string
	// tableName is the table name derived from the reference bean.
	tableName string
	// RawSQL is the raw query text set by SQL.
	RawSQL string
	// RawParams are the arguments for RawSQL.
	RawParams []interface{}
	// UseCascade is initialised to true by Init.
	UseCascade bool
	// UseAutoJoin is not reset by Init.
	UseAutoJoin bool
	// StoreEngine is passed to the dialect's CreateTableSql by genCreateTableSQL.
	StoreEngine string
	// Charset is passed to the dialect's CreateTableSql by genCreateTableSQL.
	Charset string
	// UseCache is initialised to true by Init.
	UseCache bool
	// UseAutoTime is initialised to true by Init.
	UseAutoTime bool
	// noAutoCondition is set by NoAutoCondition; when true, mergeConds does not
	// build conditions from the bean.
	noAutoCondition bool
	// IsDistinct is set by Distinct.
	IsDistinct bool
	// IsForUpdate is set by ForUpdate.
	IsForUpdate bool
	// TableAlias is set by Alias and used to qualify columns when joining.
	TableAlias string
	// allUseBool is set by UseBool with no columns; bool fields are then
	// written even when false.
	allUseBool bool
	// checkVersion is initialised to true by Init.
	checkVersion bool
	// unscoped is set by Unscoped; the soft-delete column is then writable.
	unscoped bool
	// columnMap holds the column names selected via Cols.
	columnMap columnMap
	// omitColumnMap holds the column names excluded via Omit.
	omitColumnMap columnMap
	// mustColumnMap holds lower-cased column names forced via MustCols/UseBool.
	mustColumnMap map[string]bool
	// nullableMap holds lower-cased column names registered via Nullable.
	nullableMap map[string]bool
	// incrColumns holds the "column = column + arg" updates registered by Incr.
	incrColumns exprParams
	// decrColumns holds the "column = column - arg" updates registered by Decr.
	decrColumns exprParams
	// exprColumns holds the "column = expression" updates registered by SetExpr.
	exprColumns exprParams
	// cond is the accumulated WHERE condition.
	cond builder.Cond
	// bufferSize is initialised to 0 by Init.
	bufferSize int
	// context is initialised to nil by Init.
	context ContextCache
	// lastError records an error raised by a chainable method (SQL, And, Join)
	// instead of returning it.
	lastError error
}
  59. // Init reset all the statement's fields
  60. func (statement *Statement) Init() {
  61. statement.RefTable = nil
  62. statement.Start = 0
  63. statement.LimitN = 0
  64. statement.OrderStr = ""
  65. statement.UseCascade = true
  66. statement.JoinStr = ""
  67. statement.joinArgs = make([]interface{}, 0)
  68. statement.GroupByStr = ""
  69. statement.HavingStr = ""
  70. statement.ColumnStr = ""
  71. statement.OmitStr = ""
  72. statement.columnMap = columnMap{}
  73. statement.omitColumnMap = columnMap{}
  74. statement.AltTableName = ""
  75. statement.tableName = ""
  76. statement.idParam = nil
  77. statement.RawSQL = ""
  78. statement.RawParams = make([]interface{}, 0)
  79. statement.UseCache = true
  80. statement.UseAutoTime = true
  81. statement.noAutoCondition = false
  82. statement.IsDistinct = false
  83. statement.IsForUpdate = false
  84. statement.TableAlias = ""
  85. statement.selectStr = ""
  86. statement.allUseBool = false
  87. statement.useAllCols = false
  88. statement.mustColumnMap = make(map[string]bool)
  89. statement.nullableMap = make(map[string]bool)
  90. statement.checkVersion = true
  91. statement.unscoped = false
  92. statement.incrColumns = exprParams{}
  93. statement.decrColumns = exprParams{}
  94. statement.exprColumns = exprParams{}
  95. statement.cond = builder.NewCond()
  96. statement.bufferSize = 0
  97. statement.context = nil
  98. statement.lastError = nil
  99. }
  100. // NoAutoCondition if you do not want convert bean's field as query condition, then use this function
  101. func (statement *Statement) NoAutoCondition(no ...bool) *Statement {
  102. statement.noAutoCondition = true
  103. if len(no) > 0 {
  104. statement.noAutoCondition = no[0]
  105. }
  106. return statement
  107. }
  108. // Alias set the table alias
  109. func (statement *Statement) Alias(alias string) *Statement {
  110. statement.TableAlias = alias
  111. return statement
  112. }
  113. // SQL adds raw sql statement
  114. func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement {
  115. switch query.(type) {
  116. case (*builder.Builder):
  117. var err error
  118. statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL()
  119. if err != nil {
  120. statement.lastError = err
  121. }
  122. case string:
  123. statement.RawSQL = query.(string)
  124. statement.RawParams = args
  125. default:
  126. statement.lastError = ErrUnSupportedSQLType
  127. }
  128. return statement
  129. }
  130. // Where add Where statement
  131. func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
  132. return statement.And(query, args...)
  133. }
  134. // And add Where & and statement
  135. func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
  136. switch query.(type) {
  137. case string:
  138. cond := builder.Expr(query.(string), args...)
  139. statement.cond = statement.cond.And(cond)
  140. case map[string]interface{}:
  141. queryMap := query.(map[string]interface{})
  142. newMap := make(map[string]interface{})
  143. for k, v := range queryMap {
  144. newMap[statement.Engine.Quote(k)] = v
  145. }
  146. statement.cond = statement.cond.And(builder.Eq(newMap))
  147. case builder.Cond:
  148. cond := query.(builder.Cond)
  149. statement.cond = statement.cond.And(cond)
  150. for _, v := range args {
  151. if vv, ok := v.(builder.Cond); ok {
  152. statement.cond = statement.cond.And(vv)
  153. }
  154. }
  155. default:
  156. statement.lastError = ErrConditionType
  157. }
  158. return statement
  159. }
  160. // Or add Where & Or statement
  161. func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
  162. switch query.(type) {
  163. case string:
  164. cond := builder.Expr(query.(string), args...)
  165. statement.cond = statement.cond.Or(cond)
  166. case map[string]interface{}:
  167. cond := builder.Eq(query.(map[string]interface{}))
  168. statement.cond = statement.cond.Or(cond)
  169. case builder.Cond:
  170. cond := query.(builder.Cond)
  171. statement.cond = statement.cond.Or(cond)
  172. for _, v := range args {
  173. if vv, ok := v.(builder.Cond); ok {
  174. statement.cond = statement.cond.Or(vv)
  175. }
  176. }
  177. default:
  178. // TODO: not support condition type
  179. }
  180. return statement
  181. }
  182. // In generate "Where column IN (?) " statement
  183. func (statement *Statement) In(column string, args ...interface{}) *Statement {
  184. in := builder.In(statement.Engine.Quote(column), args...)
  185. statement.cond = statement.cond.And(in)
  186. return statement
  187. }
  188. // NotIn generate "Where column NOT IN (?) " statement
  189. func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
  190. notIn := builder.NotIn(statement.Engine.Quote(column), args...)
  191. statement.cond = statement.cond.And(notIn)
  192. return statement
  193. }
  194. func (statement *Statement) setRefValue(v reflect.Value) error {
  195. var err error
  196. statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v))
  197. if err != nil {
  198. return err
  199. }
  200. statement.tableName = statement.Engine.TableName(v, true)
  201. return nil
  202. }
  203. func (statement *Statement) setRefBean(bean interface{}) error {
  204. var err error
  205. statement.RefTable, err = statement.Engine.autoMapType(rValue(bean))
  206. if err != nil {
  207. return err
  208. }
  209. statement.tableName = statement.Engine.TableName(bean, true)
  210. return nil
  211. }
// Auto generating update columnes and values according a struct
//
// buildUpdates walks the columns of statement.RefTable and collects, for an
// UPDATE built from bean, the "col = ?" SET fragments and their arguments.
//
// Parameters:
//   - bean: the value column values are read from (via col.ValueOf).
//   - includeVersion: when false, version columns are skipped.
//   - includeUpdated: when false, "updated" columns are skipped.
//   - includeNil: NOTE(review): shadowed by the per-column local
//     `includeNil := useAllCols` inside the loop, so this argument has no
//     effect on the result.
//   - includeAutoIncr: when false, auto-increment columns are skipped.
//   - update: not referenced in this body.
//
// Returns the SET fragments and the arguments, in column order.
func (statement *Statement) buildUpdates(bean interface{},
	includeVersion, includeUpdated, includeNil,
	includeAutoIncr, update bool) ([]string, []interface{}) {
	// Snapshot the statement settings used in the loop below.
	engine := statement.Engine
	table := statement.RefTable
	allUseBool := statement.allUseBool
	useAllCols := statement.useAllCols
	mustColumnMap := statement.mustColumnMap
	nullableMap := statement.nullableMap
	columnMap := statement.columnMap
	omitColumnMap := statement.omitColumnMap
	unscoped := statement.unscoped

	var colNames = make([]string, 0)
	var args = make([]interface{}, 0)
	for _, col := range table.Columns() {
		// Metadata filters: skip columns that are never written from the
		// bean, whatever its value.
		if !includeVersion && col.IsVersion {
			continue
		}
		if col.IsCreated {
			continue
		}
		if !includeUpdated && col.IsUpdated {
			continue
		}
		if !includeAutoIncr && col.IsAutoIncrement {
			continue
		}
		// The soft-delete column is only writable after Unscoped().
		if col.IsDeleted && !unscoped {
			continue
		}
		if omitColumnMap.contain(col.Name) {
			continue
		}
		// A non-empty Cols() selection restricts the update to those columns.
		if len(columnMap) > 0 && !columnMap.contain(col.Name) {
			continue
		}
		if col.MapType == core.ONLYFROMDB {
			continue
		}

		// Columns registered via Incr/Decr/SetExpr are not taken from the
		// bean; presumably the caller emits their SET clauses separately.
		if statement.incrColumns.isColExist(col.Name) {
			continue
		} else if statement.decrColumns.isColExist(col.Name) {
			continue
		} else if statement.exprColumns.isColExist(col.Name) {
			continue
		}

		fieldValuePtr, err := col.ValueOf(bean)
		if err != nil {
			engine.logger.Error(err)
			continue
		}

		fieldValue := *fieldValuePtr
		fieldType := reflect.TypeOf(fieldValue.Interface())
		// A field holding a nil interface has no dynamic type; skip it.
		if fieldType == nil {
			continue
		}

		// requiredField forces zero values to be written; AllCols() turns it
		// on for every column.
		requiredField := useAllCols
		// NOTE(review): this declaration shadows the includeNil parameter.
		includeNil := useAllCols

		// MustCols(): a true flag forces the column, a false flag skips it.
		if b, ok := getFlagForColumn(mustColumnMap, col); ok {
			if b {
				requiredField = true
			} else {
				continue
			}
		}

		// !evalphobia! set fieldValue as nil when column is nullable and zero-value
		// The zero value is replaced by a typed nil pointer so that the
		// pointer branch below writes NULL for this column.
		if b, ok := getFlagForColumn(nullableMap, col); ok {
			if b && col.Nullable && isZero(fieldValue.Interface()) {
				var nilValue *int
				fieldValue = reflect.ValueOf(nilValue)
				fieldType = reflect.TypeOf(fieldValue.Interface())
				includeNil = true
			}
		}

		var val interface{}

		// Custom conversion: if the field (by address or by value) implements
		// core.Conversion, its ToDB result is the argument. A ToDB error is
		// logged and the argument stays nil.
		if fieldValue.CanAddr() {
			if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
				data, err := structConvert.ToDB()
				if err != nil {
					engine.logger.Error(err)
				} else {
					val = data
				}
				goto APPEND
			}
		}

		if structConvert, ok := fieldValue.Interface().(core.Conversion); ok {
			data, err := structConvert.ToDB()
			if err != nil {
				engine.logger.Error(err)
			} else {
				val = data
			}
			goto APPEND
		}

		if fieldType.Kind() == reflect.Ptr {
			if fieldValue.IsNil() {
				// A nil pointer is written as NULL only when includeNil is set.
				if includeNil {
					args = append(args, nil)
					colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
				}
				continue
			} else if !fieldValue.IsValid() {
				// NOTE(review): fieldValue is known to be valid here (Interface()
				// was already called on it), so this branch looks unreachable.
				continue
			} else {
				// dereference ptr type to instance type
				// A non-nil pointer is always written, even if it points at a
				// zero value.
				fieldValue = fieldValue.Elem()
				fieldType = reflect.TypeOf(fieldValue.Interface())
				requiredField = true
			}
		}

		switch fieldType.Kind() {
		case reflect.Bool:
			if allUseBool || requiredField {
				val = fieldValue.Interface()
			} else {
				// if a bool in a struct, it will not be as a condition because it default is false,
				// please use Where() instead
				continue
			}
		case reflect.String:
			if !requiredField && fieldValue.String() == "" {
				continue
			}
			// for MyString, should convert to string or panic
			if fieldType.String() != reflect.String.String() {
				val = fieldValue.String()
			} else {
				val = fieldValue.Interface()
			}
		case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
			if !requiredField && fieldValue.Int() == 0 {
				continue
			}
			val = fieldValue.Interface()
		case reflect.Float32, reflect.Float64:
			if !requiredField && fieldValue.Float() == 0.0 {
				continue
			}
			val = fieldValue.Interface()
		case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
			if !requiredField && fieldValue.Uint() == 0 {
				continue
			}
			// Unsigned values are passed as *int64; values above MaxInt64 wrap
			// to negative in this conversion.
			t := int64(fieldValue.Uint())
			val = reflect.ValueOf(&t).Interface()
		case reflect.Struct:
			if fieldType.ConvertibleTo(core.TimeType) {
				// Time values are formatted for the column; a zero time is
				// skipped unless the column is required.
				t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
				if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
					continue
				}
				val = engine.formatColTime(col, t)
			} else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
				// NOTE(review): the error returned by Value() is discarded.
				val, _ = nulType.Value()
			} else {
				if !col.SQLType.IsJson() {
					// The result and error of autoMapType are ignored here; it
					// is called for its side effect on engine.Tables.
					engine.autoMapType(fieldValue)
					// This `table` shadows the outer RefTable variable.
					if table, ok := engine.Tables[fieldValue.Type()]; ok {
						if len(table.PrimaryKeys) == 1 {
							// A struct that maps to a known single-PK table is
							// stored as its primary-key value.
							pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
							// fix non-int pk issues
							// NOTE(review): when requiredField is true this
							// condition is false and the column is skipped,
							// which looks inverted; confirm intent.
							if pkField.IsValid() && (!requiredField && !isZero(pkField.Interface())) {
								val = pkField.Interface()
							} else {
								continue
							}
						} else {
							// TODO: how to handler?
							panic("not supported")
						}
					} else {
						val = fieldValue.Interface()
					}
				} else {
					// Blank struct could not be as update data
					// JSON columns: marshal the struct. Text columns get a
					// string, Blob columns get bytes, other types leave val nil.
					if requiredField || !isStructZero(fieldValue) {
						bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
						if err != nil {
							panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
						}
						if col.SQLType.IsText() {
							val = string(bytes)
						} else if col.SQLType.IsBlob() {
							val = bytes
						}
					} else {
						continue
					}
				}
			}
		case reflect.Array, reflect.Slice, reflect.Map:
			if !requiredField {
				// NOTE(review): this compares two reflect.Value structs with ==,
				// not the underlying data, so it is unlikely ever to be true;
				// the checks below do the real emptiness detection.
				if fieldValue == reflect.Zero(fieldType) {
					continue
				}
				if fieldType.Kind() == reflect.Array {
					if isArrayValueZero(fieldValue) {
						continue
					}
				} else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
					continue
				}
			}

			// Collections are stored as JSON text, as raw bytes, or JSON bytes
			// in blob columns; other column types skip the field.
			if col.SQLType.IsText() {
				bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
				if err != nil {
					engine.logger.Error(err)
					continue
				}
				val = string(bytes)
			} else if col.SQLType.IsBlob() {
				var bytes []byte
				var err error
				if fieldType.Kind() == reflect.Slice &&
					fieldType.Elem().Kind() == reflect.Uint8 {
					if fieldValue.Len() > 0 {
						val = fieldValue.Bytes()
					} else {
						continue
					}
				} else if fieldType.Kind() == reflect.Array &&
					fieldType.Elem().Kind() == reflect.Uint8 {
					// NOTE(review): Slice(0, 0) yields an empty slice, so the
					// array's contents are not written; reflect's Slice also
					// panics on an unaddressable array. Confirm intent.
					val = fieldValue.Slice(0, 0).Interface()
				} else {
					bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface())
					if err != nil {
						engine.logger.Error(err)
						continue
					}
					val = bytes
				}
			} else {
				continue
			}
		default:
			val = fieldValue.Interface()
		}

	APPEND:
		args = append(args, val)
		// NOTE(review): for the "ql" dialect the primary key's argument is
		// still appended but its "col = ?" fragment is not, so args and
		// colNames differ in length; confirm the caller expects this.
		if col.IsPrimaryKey && engine.dialect.DBType() == "ql" {
			continue
		}
		colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
	}

	return colNames, args
}
  460. func (statement *Statement) needTableName() bool {
  461. return len(statement.JoinStr) > 0
  462. }
  463. func (statement *Statement) colName(col *core.Column, tableName string) string {
  464. if statement.needTableName() {
  465. var nm = tableName
  466. if len(statement.TableAlias) > 0 {
  467. nm = statement.TableAlias
  468. }
  469. return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
  470. }
  471. return statement.Engine.Quote(col.Name)
  472. }
  473. // TableName return current tableName
  474. func (statement *Statement) TableName() string {
  475. if statement.AltTableName != "" {
  476. return statement.AltTableName
  477. }
  478. return statement.tableName
  479. }
  480. // ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
  481. func (statement *Statement) ID(id interface{}) *Statement {
  482. idValue := reflect.ValueOf(id)
  483. idType := reflect.TypeOf(idValue.Interface())
  484. switch idType {
  485. case ptrPkType:
  486. if pkPtr, ok := (id).(*core.PK); ok {
  487. statement.idParam = pkPtr
  488. return statement
  489. }
  490. case pkType:
  491. if pk, ok := (id).(core.PK); ok {
  492. statement.idParam = &pk
  493. return statement
  494. }
  495. }
  496. switch idType.Kind() {
  497. case reflect.String:
  498. statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
  499. return statement
  500. }
  501. statement.idParam = &core.PK{id}
  502. return statement
  503. }
  504. // Incr Generate "Update ... Set column = column + arg" statement
  505. func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
  506. if len(arg) > 0 {
  507. statement.incrColumns.addParam(column, arg[0])
  508. } else {
  509. statement.incrColumns.addParam(column, 1)
  510. }
  511. return statement
  512. }
  513. // Decr Generate "Update ... Set column = column - arg" statement
  514. func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
  515. if len(arg) > 0 {
  516. statement.decrColumns.addParam(column, arg[0])
  517. } else {
  518. statement.decrColumns.addParam(column, 1)
  519. }
  520. return statement
  521. }
  522. // SetExpr Generate "Update ... Set column = {expression}" statement
  523. func (statement *Statement) SetExpr(column string, expression interface{}) *Statement {
  524. statement.exprColumns.addParam(column, expression)
  525. return statement
  526. }
  527. func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
  528. newColumns := make([]string, 0)
  529. quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
  530. for _, col := range columns {
  531. newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...)))
  532. }
  533. return newColumns
  534. }
  535. func (statement *Statement) colmap2NewColsWithQuote() []string {
  536. newColumns := make([]string, len(statement.columnMap), len(statement.columnMap))
  537. copy(newColumns, statement.columnMap)
  538. for i := 0; i < len(statement.columnMap); i++ {
  539. newColumns[i] = statement.Engine.Quote(newColumns[i])
  540. }
  541. return newColumns
  542. }
  543. // Distinct generates "DISTINCT col1, col2 " statement
  544. func (statement *Statement) Distinct(columns ...string) *Statement {
  545. statement.IsDistinct = true
  546. statement.Cols(columns...)
  547. return statement
  548. }
  549. // ForUpdate generates "SELECT ... FOR UPDATE" statement
  550. func (statement *Statement) ForUpdate() *Statement {
  551. statement.IsForUpdate = true
  552. return statement
  553. }
  554. // Select replace select
  555. func (statement *Statement) Select(str string) *Statement {
  556. statement.selectStr = str
  557. return statement
  558. }
  559. // Cols generate "col1, col2" statement
  560. func (statement *Statement) Cols(columns ...string) *Statement {
  561. cols := col2NewCols(columns...)
  562. for _, nc := range cols {
  563. statement.columnMap.add(nc)
  564. }
  565. newColumns := statement.colmap2NewColsWithQuote()
  566. statement.ColumnStr = strings.Join(newColumns, ", ")
  567. statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
  568. return statement
  569. }
  570. // AllCols update use only: update all columns
  571. func (statement *Statement) AllCols() *Statement {
  572. statement.useAllCols = true
  573. return statement
  574. }
  575. // MustCols update use only: must update columns
  576. func (statement *Statement) MustCols(columns ...string) *Statement {
  577. newColumns := col2NewCols(columns...)
  578. for _, nc := range newColumns {
  579. statement.mustColumnMap[strings.ToLower(nc)] = true
  580. }
  581. return statement
  582. }
  583. // UseBool indicates that use bool fields as update contents and query contiditions
  584. func (statement *Statement) UseBool(columns ...string) *Statement {
  585. if len(columns) > 0 {
  586. statement.MustCols(columns...)
  587. } else {
  588. statement.allUseBool = true
  589. }
  590. return statement
  591. }
  592. // Omit do not use the columns
  593. func (statement *Statement) Omit(columns ...string) {
  594. newColumns := col2NewCols(columns...)
  595. for _, nc := range newColumns {
  596. statement.omitColumnMap = append(statement.omitColumnMap, nc)
  597. }
  598. statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
  599. }
  600. // Nullable Update use only: update columns to null when value is nullable and zero-value
  601. func (statement *Statement) Nullable(columns ...string) {
  602. newColumns := col2NewCols(columns...)
  603. for _, nc := range newColumns {
  604. statement.nullableMap[strings.ToLower(nc)] = true
  605. }
  606. }
  607. // Top generate LIMIT limit statement
  608. func (statement *Statement) Top(limit int) *Statement {
  609. statement.Limit(limit)
  610. return statement
  611. }
  612. // Limit generate LIMIT start, limit statement
  613. func (statement *Statement) Limit(limit int, start ...int) *Statement {
  614. statement.LimitN = limit
  615. if len(start) > 0 {
  616. statement.Start = start[0]
  617. }
  618. return statement
  619. }
  620. // OrderBy generate "Order By order" statement
  621. func (statement *Statement) OrderBy(order string) *Statement {
  622. if len(statement.OrderStr) > 0 {
  623. statement.OrderStr += ", "
  624. }
  625. statement.OrderStr += order
  626. return statement
  627. }
  628. // Desc generate `ORDER BY xx DESC`
  629. func (statement *Statement) Desc(colNames ...string) *Statement {
  630. var buf strings.Builder
  631. if len(statement.OrderStr) > 0 {
  632. fmt.Fprint(&buf, statement.OrderStr, ", ")
  633. }
  634. newColNames := statement.col2NewColsWithQuote(colNames...)
  635. fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
  636. statement.OrderStr = buf.String()
  637. return statement
  638. }
  639. // Asc provide asc order by query condition, the input parameters are columns.
  640. func (statement *Statement) Asc(colNames ...string) *Statement {
  641. var buf strings.Builder
  642. if len(statement.OrderStr) > 0 {
  643. fmt.Fprint(&buf, statement.OrderStr, ", ")
  644. }
  645. newColNames := statement.col2NewColsWithQuote(colNames...)
  646. fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
  647. statement.OrderStr = buf.String()
  648. return statement
  649. }
  650. // Table tempororily set table name, the parameter could be a string or a pointer of struct
  651. func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
  652. v := rValue(tableNameOrBean)
  653. t := v.Type()
  654. if t.Kind() == reflect.Struct {
  655. var err error
  656. statement.RefTable, err = statement.Engine.autoMapType(v)
  657. if err != nil {
  658. statement.Engine.logger.Error(err)
  659. return statement
  660. }
  661. }
  662. statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
  663. return statement
  664. }
  665. // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
  666. func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
  667. var buf strings.Builder
  668. if len(statement.JoinStr) > 0 {
  669. fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
  670. } else {
  671. fmt.Fprintf(&buf, "%v JOIN ", joinOP)
  672. }
  673. switch tp := tablename.(type) {
  674. case builder.Builder:
  675. subSQL, subQueryArgs, err := tp.ToSQL()
  676. if err != nil {
  677. statement.lastError = err
  678. return statement
  679. }
  680. tbs := strings.Split(tp.TableName(), ".")
  681. quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
  682. var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
  683. fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
  684. statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
  685. case *builder.Builder:
  686. subSQL, subQueryArgs, err := tp.ToSQL()
  687. if err != nil {
  688. statement.lastError = err
  689. return statement
  690. }
  691. tbs := strings.Split(tp.TableName(), ".")
  692. quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
  693. var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
  694. fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
  695. statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
  696. default:
  697. tbName := statement.Engine.TableName(tablename, true)
  698. fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
  699. }
  700. statement.JoinStr = buf.String()
  701. statement.joinArgs = append(statement.joinArgs, args...)
  702. return statement
  703. }
  704. // GroupBy generate "Group By keys" statement
  705. func (statement *Statement) GroupBy(keys string) *Statement {
  706. statement.GroupByStr = keys
  707. return statement
  708. }
  709. // Having generate "Having conditions" statement
  710. func (statement *Statement) Having(conditions string) *Statement {
  711. statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
  712. return statement
  713. }
  714. // Unscoped always disable struct tag "deleted"
  715. func (statement *Statement) Unscoped() *Statement {
  716. statement.unscoped = true
  717. return statement
  718. }
  719. func (statement *Statement) genColumnStr() string {
  720. if statement.RefTable == nil {
  721. return ""
  722. }
  723. var buf strings.Builder
  724. columns := statement.RefTable.Columns()
  725. for _, col := range columns {
  726. if statement.omitColumnMap.contain(col.Name) {
  727. continue
  728. }
  729. if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) {
  730. continue
  731. }
  732. if col.MapType == core.ONLYTODB {
  733. continue
  734. }
  735. if buf.Len() != 0 {
  736. buf.WriteString(", ")
  737. }
  738. if statement.JoinStr != "" {
  739. if statement.TableAlias != "" {
  740. buf.WriteString(statement.TableAlias)
  741. } else {
  742. buf.WriteString(statement.TableName())
  743. }
  744. buf.WriteString(".")
  745. }
  746. statement.Engine.QuoteTo(&buf, col.Name)
  747. }
  748. return buf.String()
  749. }
  750. func (statement *Statement) genCreateTableSQL() string {
  751. return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(),
  752. statement.StoreEngine, statement.Charset)
  753. }
  754. func (statement *Statement) genIndexSQL() []string {
  755. var sqls []string
  756. tbName := statement.TableName()
  757. for _, index := range statement.RefTable.Indexes {
  758. if index.Type == core.IndexType {
  759. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  760. /*idxTBName := strings.Replace(tbName, ".", "_", -1)
  761. idxTBName = strings.Replace(idxTBName, `"`, "", -1)
  762. sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)),
  763. quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/
  764. sqls = append(sqls, sql)
  765. }
  766. }
  767. return sqls
  768. }
  769. func uniqueName(tableName, uqeName string) string {
  770. return fmt.Sprintf("UQE_%v_%v", tableName, uqeName)
  771. }
  772. func (statement *Statement) genUniqueSQL() []string {
  773. var sqls []string
  774. tbName := statement.TableName()
  775. for _, index := range statement.RefTable.Indexes {
  776. if index.Type == core.UniqueType {
  777. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  778. sqls = append(sqls, sql)
  779. }
  780. }
  781. return sqls
  782. }
  783. func (statement *Statement) genDelIndexSQL() []string {
  784. var sqls []string
  785. tbName := statement.TableName()
  786. idxPrefixName := strings.Replace(tbName, `"`, "", -1)
  787. idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
  788. for idxName, index := range statement.RefTable.Indexes {
  789. var rIdxName string
  790. if index.Type == core.UniqueType {
  791. rIdxName = uniqueName(idxPrefixName, idxName)
  792. } else if index.Type == core.IndexType {
  793. rIdxName = indexName(idxPrefixName, idxName)
  794. }
  795. sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
  796. if statement.Engine.dialect.IndexOnTable() {
  797. sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
  798. }
  799. sqls = append(sqls, sql)
  800. }
  801. return sqls
  802. }
  803. func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
  804. quote := statement.Engine.Quote
  805. sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
  806. col.String(statement.Engine.dialect))
  807. if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 {
  808. sql += " COMMENT '" + col.Comment + "'"
  809. }
  810. sql += ";"
  811. return sql, []interface{}{}
  812. }
  813. func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
  814. return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
  815. statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
  816. }
  817. func (statement *Statement) mergeConds(bean interface{}) error {
  818. if !statement.noAutoCondition {
  819. var addedTableName = (len(statement.JoinStr) > 0)
  820. autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
  821. if err != nil {
  822. return err
  823. }
  824. statement.cond = statement.cond.And(autoCond)
  825. }
  826. if err := statement.processIDParam(); err != nil {
  827. return err
  828. }
  829. return nil
  830. }
  831. func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
  832. if err := statement.mergeConds(bean); err != nil {
  833. return "", nil, err
  834. }
  835. return builder.ToSQL(statement.cond)
  836. }
  837. func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
  838. v := rValue(bean)
  839. isStruct := v.Kind() == reflect.Struct
  840. if isStruct {
  841. statement.setRefBean(bean)
  842. }
  843. var columnStr = statement.ColumnStr
  844. if len(statement.selectStr) > 0 {
  845. columnStr = statement.selectStr
  846. } else {
  847. // TODO: always generate column names, not use * even if join
  848. if len(statement.JoinStr) == 0 {
  849. if len(columnStr) == 0 {
  850. if len(statement.GroupByStr) > 0 {
  851. columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
  852. } else {
  853. columnStr = statement.genColumnStr()
  854. }
  855. }
  856. } else {
  857. if len(columnStr) == 0 {
  858. if len(statement.GroupByStr) > 0 {
  859. columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
  860. }
  861. }
  862. }
  863. }
  864. if len(columnStr) == 0 {
  865. columnStr = "*"
  866. }
  867. if isStruct {
  868. if err := statement.mergeConds(bean); err != nil {
  869. return "", nil, err
  870. }
  871. } else {
  872. if err := statement.processIDParam(); err != nil {
  873. return "", nil, err
  874. }
  875. }
  876. condSQL, condArgs, err := builder.ToSQL(statement.cond)
  877. if err != nil {
  878. return "", nil, err
  879. }
  880. sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true)
  881. if err != nil {
  882. return "", nil, err
  883. }
  884. return sqlStr, append(statement.joinArgs, condArgs...), nil
  885. }
  886. func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
  887. var condSQL string
  888. var condArgs []interface{}
  889. var err error
  890. if len(beans) > 0 {
  891. statement.setRefBean(beans[0])
  892. condSQL, condArgs, err = statement.genConds(beans[0])
  893. } else {
  894. condSQL, condArgs, err = builder.ToSQL(statement.cond)
  895. }
  896. if err != nil {
  897. return "", nil, err
  898. }
  899. var selectSQL = statement.selectStr
  900. if len(selectSQL) <= 0 {
  901. if statement.IsDistinct {
  902. selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr)
  903. } else {
  904. selectSQL = "count(*)"
  905. }
  906. }
  907. sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false)
  908. if err != nil {
  909. return "", nil, err
  910. }
  911. return sqlStr, append(statement.joinArgs, condArgs...), nil
  912. }
  913. func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
  914. statement.setRefBean(bean)
  915. var sumStrs = make([]string, 0, len(columns))
  916. for _, colName := range columns {
  917. if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
  918. colName = statement.Engine.Quote(colName)
  919. }
  920. sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
  921. }
  922. sumSelect := strings.Join(sumStrs, ", ")
  923. condSQL, condArgs, err := statement.genConds(bean)
  924. if err != nil {
  925. return "", nil, err
  926. }
  927. sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true)
  928. if err != nil {
  929. return "", nil, err
  930. }
  931. return sqlStr, append(statement.joinArgs, condArgs...), nil
  932. }
  933. func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
  934. var (
  935. distinct string
  936. dialect = statement.Engine.Dialect()
  937. quote = statement.Engine.Quote
  938. fromStr = " FROM "
  939. top, mssqlCondi, whereStr string
  940. )
  941. if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
  942. distinct = "DISTINCT "
  943. }
  944. if len(condSQL) > 0 {
  945. whereStr = " WHERE " + condSQL
  946. }
  947. if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
  948. fromStr += statement.TableName()
  949. } else {
  950. fromStr += quote(statement.TableName())
  951. }
  952. if statement.TableAlias != "" {
  953. if dialect.DBType() == core.ORACLE {
  954. fromStr += " " + quote(statement.TableAlias)
  955. } else {
  956. fromStr += " AS " + quote(statement.TableAlias)
  957. }
  958. }
  959. if statement.JoinStr != "" {
  960. fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
  961. }
  962. if dialect.DBType() == core.MSSQL {
  963. if statement.LimitN > 0 {
  964. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  965. }
  966. if statement.Start > 0 {
  967. var column string
  968. if len(statement.RefTable.PKColumns()) == 0 {
  969. for _, index := range statement.RefTable.Indexes {
  970. if len(index.Cols) == 1 {
  971. column = index.Cols[0]
  972. break
  973. }
  974. }
  975. if len(column) == 0 {
  976. column = statement.RefTable.ColumnsSeq()[0]
  977. }
  978. } else {
  979. column = statement.RefTable.PKColumns()[0].Name
  980. }
  981. if statement.needTableName() {
  982. if len(statement.TableAlias) > 0 {
  983. column = statement.TableAlias + "." + column
  984. } else {
  985. column = statement.TableName() + "." + column
  986. }
  987. }
  988. var orderStr string
  989. if needOrderBy && len(statement.OrderStr) > 0 {
  990. orderStr = " ORDER BY " + statement.OrderStr
  991. }
  992. var groupStr string
  993. if len(statement.GroupByStr) > 0 {
  994. groupStr = " GROUP BY " + statement.GroupByStr
  995. }
  996. mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
  997. column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
  998. }
  999. }
  1000. var buf strings.Builder
  1001. fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
  1002. if len(mssqlCondi) > 0 {
  1003. if len(whereStr) > 0 {
  1004. fmt.Fprint(&buf, " AND ", mssqlCondi)
  1005. } else {
  1006. fmt.Fprint(&buf, " WHERE ", mssqlCondi)
  1007. }
  1008. }
  1009. if statement.GroupByStr != "" {
  1010. fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr)
  1011. }
  1012. if statement.HavingStr != "" {
  1013. fmt.Fprint(&buf, " ", statement.HavingStr)
  1014. }
  1015. if needOrderBy && statement.OrderStr != "" {
  1016. fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
  1017. }
  1018. if needLimit {
  1019. if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
  1020. if statement.Start > 0 {
  1021. fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start)
  1022. } else if statement.LimitN > 0 {
  1023. fmt.Fprint(&buf, " LIMIT ", statement.LimitN)
  1024. }
  1025. } else if dialect.DBType() == core.ORACLE {
  1026. if statement.Start != 0 || statement.LimitN != 0 {
  1027. oldString := buf.String()
  1028. buf.Reset()
  1029. rawColStr := columnStr
  1030. if rawColStr == "*" {
  1031. rawColStr = "at.*"
  1032. }
  1033. fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
  1034. columnStr, rawColStr, oldString, statement.Start+statement.LimitN, statement.Start)
  1035. }
  1036. }
  1037. }
  1038. if statement.IsForUpdate {
  1039. return dialect.ForUpdateSql(buf.String()), nil
  1040. }
  1041. return buf.String(), nil
  1042. }
  1043. func (statement *Statement) processIDParam() error {
  1044. if statement.idParam == nil || statement.RefTable == nil {
  1045. return nil
  1046. }
  1047. if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
  1048. return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
  1049. len(statement.RefTable.PrimaryKeys),
  1050. len(*statement.idParam),
  1051. )
  1052. }
  1053. for i, col := range statement.RefTable.PKColumns() {
  1054. var colName = statement.colName(col, statement.TableName())
  1055. statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
  1056. }
  1057. return nil
  1058. }
  1059. func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
  1060. var colnames = make([]string, len(cols))
  1061. for i, col := range cols {
  1062. if includeTableName {
  1063. colnames[i] = statement.Engine.Quote(statement.TableName()) +
  1064. "." + statement.Engine.Quote(col.Name)
  1065. } else {
  1066. colnames[i] = statement.Engine.Quote(col.Name)
  1067. }
  1068. }
  1069. return strings.Join(colnames, ", ")
  1070. }
  1071. func (statement *Statement) convertIDSQL(sqlStr string) string {
  1072. if statement.RefTable != nil {
  1073. cols := statement.RefTable.PKColumns()
  1074. if len(cols) == 0 {
  1075. return ""
  1076. }
  1077. colstrs := statement.joinColumns(cols, false)
  1078. sqls := splitNNoCase(sqlStr, " from ", 2)
  1079. if len(sqls) != 2 {
  1080. return ""
  1081. }
  1082. var top string
  1083. if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL {
  1084. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  1085. }
  1086. newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
  1087. return newsql
  1088. }
  1089. return ""
  1090. }
  1091. func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
  1092. if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
  1093. return "", ""
  1094. }
  1095. colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
  1096. sqls := splitNNoCase(sqlStr, "where", 2)
  1097. if len(sqls) != 2 {
  1098. if len(sqls) == 1 {
  1099. return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
  1100. colstrs, statement.Engine.Quote(statement.TableName()))
  1101. }
  1102. return "", ""
  1103. }
  1104. var whereStr = sqls[1]
  1105. // TODO: for postgres only, if any other database?
  1106. var paraStr string
  1107. if statement.Engine.dialect.DBType() == core.POSTGRES {
  1108. paraStr = "$"
  1109. } else if statement.Engine.dialect.DBType() == core.MSSQL {
  1110. paraStr = ":"
  1111. }
  1112. if paraStr != "" {
  1113. if strings.Contains(sqls[1], paraStr) {
  1114. dollers := strings.Split(sqls[1], paraStr)
  1115. whereStr = dollers[0]
  1116. for i, c := range dollers[1:] {
  1117. ccs := strings.SplitN(c, " ", 2)
  1118. whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
  1119. }
  1120. }
  1121. }
  1122. return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
  1123. colstrs, statement.Engine.Quote(statement.TableName()),
  1124. whereStr)
  1125. }
上海开阖软件有限公司 沪ICP备12045867号-1