本站源代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

172 satır
3.5KB

  1. package testfixtures
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "strings"
  6. )
// Oracle is the Oracle database helper for this package.
//
// init fills enabledConstraints and sequences from the current user's data
// dictionary; disableReferentialIntegrity and resetSequences later act on
// those snapshots.
type Oracle struct {
	baseHelper
	// enabledConstraints holds the foreign-key constraints (user_constraints
	// rows with constraint_type 'R' and status 'ENABLED') captured by init.
	// They are disabled around a fixture load and re-enabled afterwards.
	enabledConstraints []oracleConstraint
	// sequences holds the sequence names from user_sequences captured by
	// init; resetSequences drops and re-creates each of them.
	sequences []string
}
  13. type oracleConstraint struct {
  14. tableName string
  15. constraintName string
  16. }
  17. func (h *Oracle) init(db *sql.DB) error {
  18. var err error
  19. h.enabledConstraints, err = h.getEnabledConstraints(db)
  20. if err != nil {
  21. return err
  22. }
  23. h.sequences, err = h.getSequences(db)
  24. if err != nil {
  25. return err
  26. }
  27. return nil
  28. }
  29. func (*Oracle) paramType() int {
  30. return paramTypeColon
  31. }
  32. func (*Oracle) quoteKeyword(str string) string {
  33. return fmt.Sprintf("\"%s\"", strings.ToUpper(str))
  34. }
  35. func (*Oracle) databaseName(q queryable) (string, error) {
  36. var dbName string
  37. err := q.QueryRow("SELECT user FROM DUAL").Scan(&dbName)
  38. return dbName, err
  39. }
  40. func (*Oracle) tableNames(q queryable) ([]string, error) {
  41. query := `
  42. SELECT TABLE_NAME
  43. FROM USER_TABLES
  44. `
  45. rows, err := q.Query(query)
  46. if err != nil {
  47. return nil, err
  48. }
  49. defer rows.Close()
  50. var tables []string
  51. for rows.Next() {
  52. var table string
  53. if err = rows.Scan(&table); err != nil {
  54. return nil, err
  55. }
  56. tables = append(tables, table)
  57. }
  58. if err = rows.Err(); err != nil {
  59. return nil, err
  60. }
  61. return tables, nil
  62. }
  63. func (*Oracle) getEnabledConstraints(q queryable) ([]oracleConstraint, error) {
  64. var constraints []oracleConstraint
  65. rows, err := q.Query(`
  66. SELECT table_name, constraint_name
  67. FROM user_constraints
  68. WHERE constraint_type = 'R'
  69. AND status = 'ENABLED'
  70. `)
  71. if err != nil {
  72. return nil, err
  73. }
  74. defer rows.Close()
  75. for rows.Next() {
  76. var constraint oracleConstraint
  77. rows.Scan(&constraint.tableName, &constraint.constraintName)
  78. constraints = append(constraints, constraint)
  79. }
  80. if err = rows.Err(); err != nil {
  81. return nil, err
  82. }
  83. return constraints, nil
  84. }
  85. func (*Oracle) getSequences(q queryable) ([]string, error) {
  86. var sequences []string
  87. rows, err := q.Query("SELECT sequence_name FROM user_sequences")
  88. if err != nil {
  89. return nil, err
  90. }
  91. defer rows.Close()
  92. for rows.Next() {
  93. var sequence string
  94. if err = rows.Scan(&sequence); err != nil {
  95. return nil, err
  96. }
  97. sequences = append(sequences, sequence)
  98. }
  99. if err = rows.Err(); err != nil {
  100. return nil, err
  101. }
  102. return sequences, nil
  103. }
  104. func (h *Oracle) resetSequences(q queryable) error {
  105. for _, sequence := range h.sequences {
  106. _, err := q.Exec(fmt.Sprintf("DROP SEQUENCE %s", h.quoteKeyword(sequence)))
  107. if err != nil {
  108. return err
  109. }
  110. _, err = q.Exec(fmt.Sprintf("CREATE SEQUENCE %s START WITH %d", h.quoteKeyword(sequence), resetSequencesTo))
  111. if err != nil {
  112. return err
  113. }
  114. }
  115. return nil
  116. }
  117. func (h *Oracle) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
  118. // re-enable after load
  119. defer func() {
  120. for _, c := range h.enabledConstraints {
  121. _, err2 := db.Exec(fmt.Sprintf("ALTER TABLE %s ENABLE CONSTRAINT %s", h.quoteKeyword(c.tableName), h.quoteKeyword(c.constraintName)))
  122. if err2 != nil && err == nil {
  123. err = err2
  124. }
  125. }
  126. }()
  127. // disable foreign keys
  128. for _, c := range h.enabledConstraints {
  129. _, err := db.Exec(fmt.Sprintf("ALTER TABLE %s DISABLE CONSTRAINT %s", h.quoteKeyword(c.tableName), h.quoteKeyword(c.constraintName)))
  130. if err != nil {
  131. return err
  132. }
  133. }
  134. tx, err := db.Begin()
  135. if err != nil {
  136. return err
  137. }
  138. defer tx.Rollback()
  139. if err = loadFn(tx); err != nil {
  140. return err
  141. }
  142. if err = tx.Commit(); err != nil {
  143. return err
  144. }
  145. return h.resetSequences(db)
  146. }
上海开阖软件有限公司 沪ICP备12045867号-1