Skip to content

Commit f73cafc

Browse files
committed
make sure we set enable_load_extension back to false
1 parent fed9521 commit f73cafc

File tree

2 files changed

+30
-14
lines changed

2 files changed

+30
-14
lines changed

lib/sqlite3/database.rb

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,13 @@ module SQLite3
6666
# db.load_extension(SQLean::Crypto)
6767
#
6868
# It's also possible in v2.4.0+ to load extensions via the SQLite3::Database constructor by using
69-
# the +extensions:+ keyword argument to pass an array of strings or extension specifiers:
69+
# the +extensions:+ keyword argument to pass an array of String paths or extension specifiers:
7070
#
7171
# db = SQLite3::Database.new(":memory:", extensions: ["/path/to/extension", SQLean::Crypto])
7272
#
73-
# Note that the constructor will implicitly call #enable_load_extension if the +extensions:+
74-
# keyword argument is present.
73+
# Note that when loading extensions via the constructor, there is no need to call
74+
# #enable_load_extension; however it is still necessary to call #enable_load_extensions before any
75+
# subsequently invocations of #load_extension on the initialized Database object.
7576
#
7677
class Database
7778
attr_reader :collations
@@ -729,10 +730,14 @@ def initialize_extensions(extensions) # :nodoc:
729730
raise TypeError, "extensions must be an Array" unless extensions.is_a?(Array)
730731
return if extensions.empty?
731732

732-
enable_load_extension(true)
733+
begin
734+
enable_load_extension(true)
733735

734-
extensions.each do |extension|
735-
load_extension(extension)
736+
extensions.each do |extension|
737+
load_extension(extension)
738+
end
739+
ensure
740+
enable_load_extension(false)
736741
end
737742
end
738743

test/test_database.rb

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -701,25 +701,36 @@ def test_initialize_extensions_with_extensions_calls_enable_load_extension
701701
mock_database_load_extension_internal(db)
702702
class << db
703703
attr_accessor :enable_load_extension_called
704+
attr_reader :enable_load_extension_arg
704705

705-
def enable_load_extension(...)
706-
@enable_load_extension_called = true
706+
def reset_test
707+
@enable_load_extension_called = 0
708+
@enable_load_extension_arg = []
709+
end
710+
711+
def enable_load_extension(val)
712+
@enable_load_extension_called += 1
713+
@enable_load_extension_arg << val
707714
end
708715
end
709716

717+
db.reset_test
710718
db.initialize_extensions(nil)
711-
refute(db.enable_load_extension_called)
719+
assert_equal(0, db.enable_load_extension_called)
712720

721+
db.reset_test
713722
db.initialize_extensions([])
714-
refute(db.enable_load_extension_called)
723+
assert_equal(0, db.enable_load_extension_called)
715724

725+
db.reset_test
716726
db.initialize_extensions(["/path/to/extension"])
717-
assert(db.enable_load_extension_called)
718-
719-
db.enable_load_extension_called = false # reset
727+
assert_equal(2, db.enable_load_extension_called)
728+
assert_equal([true, false], db.enable_load_extension_arg)
720729

730+
db.reset_test
721731
db.initialize_extensions([FakeExtensionSpecifier])
722-
assert(db.enable_load_extension_called)
732+
assert_equal(2, db.enable_load_extension_called)
733+
assert_equal([true, false], db.enable_load_extension_arg)
723734
end
724735

725736
def test_initialize_extensions_object_is_an_extension_specifier

0 commit comments

Comments
 (0)