|
- package testfixtures
-
- import (
- "database/sql"
- "fmt"
- "strings"
- )
-
- // SQLServer is the helper for SQL Server for this package.
- // SQL Server >= 2008 is required.
- type SQLServer struct {
- baseHelper
-
- tables []string
- }
-
- func (h *SQLServer) init(db *sql.DB) error {
- var err error
-
- h.tables, err = h.tableNames(db)
- if err != nil {
- return err
- }
-
- return nil
- }
-
- func (*SQLServer) paramType() int {
- return paramTypeQuestion
- }
-
- func (*SQLServer) quoteKeyword(s string) string {
- parts := strings.Split(s, ".")
- for i, p := range parts {
- parts[i] = fmt.Sprintf(`[%s]`, p)
- }
- return strings.Join(parts, ".")
- }
-
- func (*SQLServer) databaseName(q queryable) (string, error) {
- var dbName string
- err := q.QueryRow("SELECT DB_NAME()").Scan(&dbName)
- return dbName, err
- }
-
- func (*SQLServer) tableNames(q queryable) ([]string, error) {
- rows, err := q.Query("SELECT table_schema + '.' + table_name FROM information_schema.tables")
- if err != nil {
- return nil, err
- }
- defer rows.Close()
-
- var tables []string
- for rows.Next() {
- var table string
- if err = rows.Scan(&table); err != nil {
- return nil, err
- }
- tables = append(tables, table)
- }
- if err = rows.Err(); err != nil {
- return nil, err
- }
- return tables, nil
- }
-
- func (h *SQLServer) tableHasIdentityColumn(q queryable, tableName string) bool {
- sql := `
- SELECT COUNT(*)
- FROM SYS.IDENTITY_COLUMNS
- WHERE OBJECT_ID = OBJECT_ID(?)
- `
- var count int
- q.QueryRow(sql, h.quoteKeyword(tableName)).Scan(&count)
- return count > 0
-
- }
-
- func (h *SQLServer) whileInsertOnTable(tx *sql.Tx, tableName string, fn func() error) (err error) {
- if h.tableHasIdentityColumn(tx, tableName) {
- defer func() {
- _, err2 := tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", h.quoteKeyword(tableName)))
- if err2 != nil && err == nil {
- err = err2
- }
- }()
-
- _, err := tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s ON", h.quoteKeyword(tableName)))
- if err != nil {
- return err
- }
- }
- return fn()
- }
-
- func (h *SQLServer) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
- // ensure the triggers are re-enable after all
- defer func() {
- var sql string
- for _, table := range h.tables {
- sql += fmt.Sprintf("ALTER TABLE %s WITH CHECK CHECK CONSTRAINT ALL;", h.quoteKeyword(table))
- }
- if _, err2 := db.Exec(sql); err2 != nil && err == nil {
- err = err2
- }
- }()
-
- var sql string
- for _, table := range h.tables {
- sql += fmt.Sprintf("ALTER TABLE %s NOCHECK CONSTRAINT ALL;", h.quoteKeyword(table))
- }
- if _, err := db.Exec(sql); err != nil {
- return err
- }
-
- tx, err := db.Begin()
- if err != nil {
- return err
- }
- defer tx.Rollback()
-
- if err = loadFn(tx); err != nil {
- return err
- }
-
- return tx.Commit()
- }
-
- // splitter is a batchSplitter interface implementation. We need it for
- // SQL Server because commands like a `CREATE SCHEMA...` and a `CREATE TABLE...`
- // could not be executed in the same batch.
- // See https://docs.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms175502(v=sql.105)#rules-for-using-batches
- func (*SQLServer) splitter() []byte {
- return []byte("GO\n")
- }
|