diff --git a/codecov.yml b/codecov.yml index c14f855b9f..9e1ee94add 100644 --- a/codecov.yml +++ b/codecov.yml @@ -32,3 +32,4 @@ ignore: - "*.xaml" - "*.xaml.cs" - "**/SampleData/*" + - "src/GitHub.App/sqlite-net/*" \ No newline at end of file diff --git a/src/GitHub.App/Models/Drafts/CommentDraft.cs b/src/GitHub.App/Models/Drafts/CommentDraft.cs new file mode 100644 index 0000000000..a55f23aab3 --- /dev/null +++ b/src/GitHub.App/Models/Drafts/CommentDraft.cs @@ -0,0 +1,15 @@ +using GitHub.ViewModels; + +namespace GitHub.Models.Drafts +{ + /// + /// Stores a draft for a + /// + public class CommentDraft + { + /// + /// Gets or sets the draft comment body. + /// + public string Body { get; set; } + } +} diff --git a/src/GitHub.App/Models/Drafts/PullRequestDraft.cs b/src/GitHub.App/Models/Drafts/PullRequestDraft.cs new file mode 100644 index 0000000000..fa359f4c13 --- /dev/null +++ b/src/GitHub.App/Models/Drafts/PullRequestDraft.cs @@ -0,0 +1,20 @@ +using GitHub.ViewModels.GitHubPane; + +namespace GitHub.Models.Drafts +{ + /// + /// Stores a draft for a . + /// + public class PullRequestDraft + { + /// + /// Gets or sets the draft pull request title. + /// + public string Title { get; set; } + + /// + /// Gets or sets the draft pull request body. + /// + public string Body { get; set; } + } +} diff --git a/src/GitHub.App/Models/Drafts/PullRequestReviewCommentDraft.cs b/src/GitHub.App/Models/Drafts/PullRequestReviewCommentDraft.cs new file mode 100644 index 0000000000..e9e29be412 --- /dev/null +++ b/src/GitHub.App/Models/Drafts/PullRequestReviewCommentDraft.cs @@ -0,0 +1,21 @@ +using System; +using GitHub.ViewModels; + +namespace GitHub.Models.Drafts +{ + /// + /// Stores a draft for a + /// + public class PullRequestReviewCommentDraft : CommentDraft + { + /// + /// Gets or sets the side of the diff that the draft comment was left on. + /// + public DiffSide Side { get; set; } + + /// + /// Gets or sets the time that the draft was last updated. + /// + public DateTimeOffset UpdatedAt { get; set; } + } +} diff --git a/src/GitHub.App/Models/Drafts/PullRequestReviewDraft.cs b/src/GitHub.App/Models/Drafts/PullRequestReviewDraft.cs new file mode 100644 index 0000000000..3dd3a891fb --- /dev/null +++ b/src/GitHub.App/Models/Drafts/PullRequestReviewDraft.cs @@ -0,0 +1,15 @@ +using GitHub.ViewModels.GitHubPane; + +namespace GitHub.Models.Drafts +{ + /// + /// Stores a draft for a . + /// + public class PullRequestReviewDraft + { + /// + /// Gets or sets the draft review body. + /// + public string Body { get; set; } + } +} diff --git a/src/GitHub.App/SampleData/CommentThreadViewModelDesigner.cs b/src/GitHub.App/SampleData/CommentThreadViewModelDesigner.cs index bc8f3955bb..cd282d81c7 100644 --- a/src/GitHub.App/SampleData/CommentThreadViewModelDesigner.cs +++ b/src/GitHub.App/SampleData/CommentThreadViewModelDesigner.cs @@ -14,8 +14,8 @@ public class CommentThreadViewModelDesigner : ViewModelBase, ICommentThreadViewM public IActorViewModel CurrentUser { get; set; } = new ActorViewModel { Login = "shana" }; - public Task DeleteComment(int pullRequestId, int commentId) => Task.CompletedTask; - public Task EditComment(string id, string body) => Task.CompletedTask; - public Task PostComment(string body) => Task.CompletedTask; + public Task DeleteComment(ICommentViewModel comment) => Task.CompletedTask; + public Task EditComment(ICommentViewModel comment) => Task.CompletedTask; + public Task PostComment(ICommentViewModel comment) => Task.CompletedTask; } } diff --git a/src/GitHub.InlineReviews/Services/InlineCommentPeekService.cs b/src/GitHub.App/Services/InlineCommentPeekService.cs similarity index 72% rename from src/GitHub.InlineReviews/Services/InlineCommentPeekService.cs rename to src/GitHub.App/Services/InlineCommentPeekService.cs index 892a7c47b1..463bfdef5e 100644 --- a/src/GitHub.InlineReviews/Services/InlineCommentPeekService.cs +++ b/src/GitHub.App/Services/InlineCommentPeekService.cs @@ -2,15 +2,9 @@ using System.ComponentModel.Composition; using System.Linq; using System.Reactive.Linq; -using System.Threading.Tasks; -using GitHub.Api; using GitHub.Extensions; -using GitHub.Factories; -using GitHub.InlineReviews.Peek; -using GitHub.InlineReviews.Tags; using GitHub.Models; -using GitHub.Primitives; -using GitHub.Services; +using GitHub.ViewModels; using Microsoft.VisualStudio.Language.Intellisense; using Microsoft.VisualStudio.Text; using Microsoft.VisualStudio.Text.Differencing; @@ -18,7 +12,7 @@ using Microsoft.VisualStudio.Text.Outlining; using Microsoft.VisualStudio.Text.Projection; -namespace GitHub.InlineReviews.Services +namespace GitHub.Services { /// /// Shows inline comments in a peek view. @@ -26,6 +20,7 @@ namespace GitHub.InlineReviews.Services [Export(typeof(IInlineCommentPeekService))] class InlineCommentPeekService : IInlineCommentPeekService { + const string relationship = "GitHubCodeReview"; readonly IOutliningManagerService outliningService; readonly IPeekBroker peekBroker; readonly IUsageTracker usageTracker; @@ -90,69 +85,46 @@ public void Hide(ITextView textView) } /// - public ITrackingPoint Show(ITextView textView, AddInlineCommentTag tag) + public ITrackingPoint Show(ITextView textView, DiffSide side, int lineNumber) { - Guard.ArgumentNotNull(tag, nameof(tag)); - - var lineAndtrackingPoint = GetLineAndTrackingPoint(textView, tag); + var lineAndtrackingPoint = GetLineAndTrackingPoint(textView, side, lineNumber); var line = lineAndtrackingPoint.Item1; var trackingPoint = lineAndtrackingPoint.Item2; var options = new PeekSessionCreationOptions( textView, - InlineCommentPeekRelationship.Instance.Name, + relationship, trackingPoint, defaultHeight: 0); ExpandCollapsedRegions(textView, line.Extent); var session = peekBroker.TriggerPeekSession(options); - var item = session.PeekableItems.OfType().FirstOrDefault(); - item?.ViewModel.Close.Take(1).Subscribe(_ => session.Dismiss()); - - return trackingPoint; - } - - /// - public ITrackingPoint Show(ITextView textView, ShowInlineCommentTag tag) - { - Guard.ArgumentNotNull(textView, nameof(textView)); - Guard.ArgumentNotNull(tag, nameof(tag)); - - var lineAndtrackingPoint = GetLineAndTrackingPoint(textView, tag); - var line = lineAndtrackingPoint.Item1; - var trackingPoint = lineAndtrackingPoint.Item2; - var options = new PeekSessionCreationOptions( - textView, - InlineCommentPeekRelationship.Instance.Name, - trackingPoint, - defaultHeight: 0); - - ExpandCollapsedRegions(textView, line.Extent); - - var session = peekBroker.TriggerPeekSession(options); - var item = session.PeekableItems.OfType().FirstOrDefault(); - item?.ViewModel.Close.Take(1).Subscribe(_ => session.Dismiss()); - + var item = session.PeekableItems.OfType().FirstOrDefault(); + item?.Closed.Take(1).Subscribe(_ => session.Dismiss()); + return trackingPoint; } - Tuple GetLineAndTrackingPoint(ITextView textView, InlineCommentTag tag) + Tuple GetLineAndTrackingPoint( + ITextView textView, + DiffSide side, + int lineNumber) { var diffModel = (textView as IWpfTextView)?.TextViewModel as IDifferenceTextViewModel; var snapshot = textView.TextSnapshot; if (diffModel?.ViewType == DifferenceViewType.InlineView) { - snapshot = tag.DiffChangeType == DiffChangeType.Delete ? + snapshot = side == DiffSide.Left ? diffModel.Viewer.DifferenceBuffer.LeftBuffer.CurrentSnapshot : diffModel.Viewer.DifferenceBuffer.RightBuffer.CurrentSnapshot; } - var line = snapshot.GetLineFromLineNumber(tag.LineNumber); + var line = snapshot.GetLineFromLineNumber(lineNumber); var trackingPoint = snapshot.CreateTrackingPoint(line.Start.Position, PointTrackingMode.Positive); ExpandCollapsedRegions(textView, line.Extent); - peekBroker.TriggerPeekSession(textView, trackingPoint, InlineCommentPeekRelationship.Instance.Name); + peekBroker.TriggerPeekSession(textView, trackingPoint, relationship); usageTracker.IncrementCounter(x => x.NumberOfPRReviewDiffViewInlineCommentOpen).Forget(); diff --git a/src/GitHub.App/Services/PullRequestEditorService.cs b/src/GitHub.App/Services/PullRequestEditorService.cs index 57b1dfff5e..3135c29730 100644 --- a/src/GitHub.App/Services/PullRequestEditorService.cs +++ b/src/GitHub.App/Services/PullRequestEditorService.cs @@ -5,9 +5,12 @@ using System.Linq; using System.Reactive.Linq; using System.Threading.Tasks; +using EnvDTE; using GitHub.Commands; using GitHub.Extensions; using GitHub.Models; +using GitHub.Models.Drafts; +using GitHub.ViewModels; using GitHub.VisualStudio; using Microsoft.VisualStudio; using Microsoft.VisualStudio.Editor; @@ -18,7 +21,6 @@ using Microsoft.VisualStudio.Text.Editor; using Microsoft.VisualStudio.Text.Projection; using Microsoft.VisualStudio.TextManager.Interop; -using EnvDTE; using Task = System.Threading.Tasks.Task; namespace GitHub.Services @@ -39,6 +41,8 @@ public class PullRequestEditorService : IPullRequestEditorService readonly IStatusBarNotificationService statusBar; readonly IGoToSolutionOrPullRequestFileCommand goToSolutionOrPullRequestFileCommand; readonly IEditorOptionsFactoryService editorOptionsFactoryService; + readonly IMessageDraftStore draftStore; + readonly IInlineCommentPeekService peekService; readonly IUsageTracker usageTracker; [ImportingConstructor] @@ -49,6 +53,8 @@ public PullRequestEditorService( IStatusBarNotificationService statusBar, IGoToSolutionOrPullRequestFileCommand goToSolutionOrPullRequestFileCommand, IEditorOptionsFactoryService editorOptionsFactoryService, + IMessageDraftStore draftStore, + IInlineCommentPeekService peekService, IUsageTracker usageTracker) { Guard.ArgumentNotNull(serviceProvider, nameof(serviceProvider)); @@ -58,6 +64,8 @@ public PullRequestEditorService( Guard.ArgumentNotNull(goToSolutionOrPullRequestFileCommand, nameof(goToSolutionOrPullRequestFileCommand)); Guard.ArgumentNotNull(goToSolutionOrPullRequestFileCommand, nameof(editorOptionsFactoryService)); Guard.ArgumentNotNull(usageTracker, nameof(usageTracker)); + Guard.ArgumentNotNull(peekService, nameof(peekService)); + Guard.ArgumentNotNull(draftStore, nameof(draftStore)); this.serviceProvider = serviceProvider; this.pullRequestService = pullRequestService; @@ -65,6 +73,8 @@ public PullRequestEditorService( this.statusBar = statusBar; this.goToSolutionOrPullRequestFileCommand = goToSolutionOrPullRequestFileCommand; this.editorOptionsFactoryService = editorOptionsFactoryService; + this.draftStore = draftStore; + this.peekService = peekService; this.usageTracker = usageTracker; } @@ -129,7 +139,7 @@ public async Task OpenFile( } /// - public async Task OpenDiff(IPullRequestSession session, string relativePath, string headSha, bool scrollToFirstDiff) + public async Task OpenDiff(IPullRequestSession session, string relativePath, string headSha, bool scrollToFirstDraftOrDiff) { Guard.ArgumentNotNull(session, nameof(session)); Guard.ArgumentNotEmptyString(relativePath, nameof(relativePath)); @@ -168,12 +178,37 @@ await pullRequestService.ExtractToTempFile( var caption = $"Diff - {Path.GetFileName(file.RelativePath)}"; var options = __VSDIFFSERVICEOPTIONS.VSDIFFOPT_DetectBinaryFiles | __VSDIFFSERVICEOPTIONS.VSDIFFOPT_LeftFileIsTemporary; + var openThread = (line: -1, side: DiffSide.Left); + var scrollToFirstDiff = false; if (!workingDirectory) { options |= __VSDIFFSERVICEOPTIONS.VSDIFFOPT_RightFileIsTemporary; } + if (scrollToFirstDraftOrDiff) + { + var (key, _) = PullRequestReviewCommentThreadViewModel.GetDraftKeys( + session.LocalRepository.CloneUrl.WithOwner(session.RepositoryOwner), + session.PullRequest.Number, + relativePath, + 0); + var drafts = (await draftStore.GetDrafts(key) + .ConfigureAwait(true)) + .OrderByDescending(x => x.data.UpdatedAt) + .ToList(); + + if (drafts.Count > 0 && int.TryParse(drafts[0].secondaryKey, out var line)) + { + openThread = (line, drafts[0].data.Side); + scrollToFirstDiff = false; + } + else + { + scrollToFirstDiff = true; + } + } + IVsWindowFrame frame; using (OpenWithOption(DifferenceViewerOptions.ScrollToFirstDiffName, scrollToFirstDiff)) using (OpenInProvisionalTab()) @@ -228,6 +263,18 @@ await pullRequestService.ExtractToTempFile( else await usageTracker.IncrementCounter(x => x.NumberOfPRDetailsViewChanges); + if (openThread.line != -1) + { + var view = diffViewer.ViewMode == DifferenceViewMode.Inline ? + diffViewer.InlineView : + openThread.side == DiffSide.Left ? diffViewer.LeftView : diffViewer.RightView; + + // HACK: We need to wait here for the view to initialize or the peek session won't appear. + // There must be a better way of doing this. + await Task.Delay(1500).ConfigureAwait(true); + peekService.Show(view, openThread.side, openThread.line); + } + return diffViewer; } catch (Exception e) @@ -247,7 +294,7 @@ public async Task OpenDiff( Guard.ArgumentNotEmptyString(relativePath, nameof(relativePath)); Guard.ArgumentNotNull(thread, nameof(thread)); - var diffViewer = await OpenDiff(session, relativePath, thread.CommitSha, scrollToFirstDiff: false); + var diffViewer = await OpenDiff(session, relativePath, thread.CommitSha, scrollToFirstDraftOrDiff: false); var param = (object)new InlineCommentNavigationParams { diff --git a/src/GitHub.App/Services/SqliteMessageDraftStore.cs b/src/GitHub.App/Services/SqliteMessageDraftStore.cs new file mode 100644 index 0000000000..08463e0d3b --- /dev/null +++ b/src/GitHub.App/Services/SqliteMessageDraftStore.cs @@ -0,0 +1,179 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel.Composition; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using GitHub.Extensions; +using GitHub.Logging; +using Newtonsoft.Json; +using Rothko; +using Serilog; +using SQLite; + +namespace GitHub.Services +{ + /// + /// Stores drafts of messages in an SQL database. + /// + [Export(typeof(IMessageDraftStore))] + [PartCreationPolicy(CreationPolicy.Shared)] + public class SqliteMessageDraftStore : IMessageDraftStore + { + static readonly ILogger log = LogManager.ForContext(); + readonly IOperatingSystem os; + SQLiteAsyncConnection connection; + bool initialized; + + [ImportingConstructor] + public SqliteMessageDraftStore(IOperatingSystem os) + { + this.os = os; + } + + public async Task GetDraft(string key, string secondaryKey) where T : class + { + Guard.ArgumentNotEmptyString(key, nameof(key)); + Guard.ArgumentNotNull(secondaryKey, nameof(secondaryKey)); + + if (await Initialize().ConfigureAwait(false)) + { + try + { + var result = await connection.Table().Where( + x => x.Key == key && x.SecondaryKey == secondaryKey) + .FirstOrDefaultAsync() + .ConfigureAwait(false); + + if (result != null) + { + return JsonConvert.DeserializeObject(result.Data); + } + } + catch (Exception ex) + { + log.Error(ex, "Failed to load message draft into {Type}", typeof(T)); + } + } + + return null; + } + + public async Task> GetDrafts(string key) where T : class + { + Guard.ArgumentNotEmptyString(key, nameof(key)); + + if (await Initialize().ConfigureAwait(false)) + { + try + { + var result = await connection.Table().Where(x => x.Key == key) + .ToListAsync() + .ConfigureAwait(false); + + return result.Select(x => (x.SecondaryKey, JsonConvert.DeserializeObject(x.Data))); + } + catch (Exception ex) + { + log.Error(ex, "Failed to load message drafts into {Type}", typeof(T)); + } + } + + return null; + } + + public async Task UpdateDraft(string key, string secondaryKey, T data) where T : class + { + Guard.ArgumentNotEmptyString(key, nameof(key)); + Guard.ArgumentNotNull(secondaryKey, nameof(secondaryKey)); + + if (!await Initialize().ConfigureAwait(false)) + { + return; + } + + try + { + var row = new Draft + { + Key = key, + SecondaryKey = secondaryKey, + Data = JsonConvert.SerializeObject(data), + }; + + await connection.InsertOrReplaceAsync(row).ConfigureAwait(false); + } + catch (Exception ex) + { + log.Error(ex, "Failed to update message draft"); + } + } + + public async Task DeleteDraft(string key, string secondaryKey) + { + Guard.ArgumentNotEmptyString(key, nameof(key)); + Guard.ArgumentNotNull(secondaryKey, nameof(secondaryKey)); + + if (!await Initialize().ConfigureAwait(false)) + { + return; + } + + try + { + await connection.ExecuteAsync( + "DELETE FROM Drafts WHERE Key=? AND SecondaryKey=?", + key, + secondaryKey).ConfigureAwait(false); + } + catch (Exception ex) + { + log.Error(ex, "Failed to update message draft"); + } + } + + async Task Initialize() + { + if (!initialized) + { + var path = Path.Combine(os.Environment.GetApplicationDataPath(), "drafts.db"); + + try + { + connection = new SQLiteAsyncConnection(path); + + var draftsTable = await connection.GetTableInfoAsync("Drafts").ConfigureAwait(false); + + if (draftsTable.Count == 0) + { + await connection.ExecuteAsync(@" + CREATE TABLE `Drafts` ( + `Key` TEXT, + `SecondaryKey` TEXT, + `Data` TEXT, + UNIQUE(`Key`,`SecondaryKey`) + );").ConfigureAwait(false); + } + } + catch (Exception ex) + { + log.Error(ex, "Error opening drafts from {Path}.", path); + } + finally + { + initialized = true; + } + } + + return connection != null; + } + + [Table("Drafts")] + private class Draft + { + public string Key { get; set; } + public string SecondaryKey { get; set; } + public string Data { get; set; } + } + } +} diff --git a/src/GitHub.App/ViewModels/CommentThreadViewModel.cs b/src/GitHub.App/ViewModels/CommentThreadViewModel.cs index 1fae1484b2..bd5a9541a8 100644 --- a/src/GitHub.App/ViewModels/CommentThreadViewModel.cs +++ b/src/GitHub.App/ViewModels/CommentThreadViewModel.cs @@ -1,7 +1,14 @@ -using System.ComponentModel.Composition; +using System; +using System.Collections.Generic; +using System.ComponentModel.Composition; +using System.Reactive.Concurrency; +using System.Reactive.Linq; +using System.Reactive.Subjects; using System.Threading.Tasks; using GitHub.Extensions; using GitHub.Models; +using GitHub.Models.Drafts; +using GitHub.Services; using ReactiveUI; namespace GitHub.ViewModels @@ -12,24 +19,36 @@ namespace GitHub.ViewModels public abstract class CommentThreadViewModel : ReactiveObject, ICommentThreadViewModel { readonly ReactiveList comments = new ReactiveList(); + readonly Dictionary> draftThrottles = + new Dictionary>(); + readonly IScheduler timerScheduler; /// /// Initializes a new instance of the class. /// + /// The message draft store. [ImportingConstructor] - public CommentThreadViewModel() + public CommentThreadViewModel(IMessageDraftStore draftStore) + : this(draftStore, DefaultScheduler.Instance) { } /// - /// Intializes a new instance of the class. + /// Initializes a new instance of the class. /// - /// The current user. - protected Task InitializeAsync(ActorModel currentUser) + /// The message draft store. + /// + /// The scheduler to use to apply a throttle to message drafts. + /// + [ImportingConstructor] + public CommentThreadViewModel( + IMessageDraftStore draftStore, + IScheduler timerScheduler) { - Guard.ArgumentNotNull(currentUser, nameof(currentUser)); - CurrentUser = new ActorViewModel(currentUser); - return Task.CompletedTask; + Guard.ArgumentNotNull(draftStore, nameof(draftStore)); + + DraftStore = draftStore; + this.timerScheduler = timerScheduler; } /// @@ -41,13 +60,102 @@ protected Task InitializeAsync(ActorModel currentUser) /// IReadOnlyReactiveList ICommentThreadViewModel.Comments => comments; + protected IMessageDraftStore DraftStore { get; } + /// - public abstract Task PostComment(string body); + public abstract Task PostComment(ICommentViewModel comment); /// - public abstract Task EditComment(string id, string body); + public abstract Task EditComment(ICommentViewModel comment); /// - public abstract Task DeleteComment(int pullRequestId, int commentId); + public abstract Task DeleteComment(ICommentViewModel comment); + + /// + /// Adds a placeholder comment that will allow the user to enter a reply, and wires up + /// event listeners for saving drafts. + /// + /// The placeholder comment view model. + /// An object which when disposed will remove the event listeners. + protected IDisposable AddPlaceholder(ICommentViewModel placeholder) + { + Comments.Add(placeholder); + + return placeholder.WhenAnyValue( + x => x.EditState, + x => x.Body, + (state, body) => (state, body)) + .Subscribe(x => PlaceholderChanged(placeholder, x.state, x.body)); + } + + /// + /// Intializes a new instance of the class. + /// + /// The current user. + protected Task InitializeAsync(ActorModel currentUser) + { + Guard.ArgumentNotNull(currentUser, nameof(currentUser)); + CurrentUser = new ActorViewModel(currentUser); + return Task.CompletedTask; + } + + protected virtual CommentDraft BuildDraft(ICommentViewModel comment) + { + return !string.IsNullOrEmpty(comment.Body) ? + new CommentDraft { Body = comment.Body } : + null; + } + + protected async Task DeleteDraft(ICommentViewModel comment) + { + if (draftThrottles.TryGetValue(comment, out var throttle)) + { + throttle.OnCompleted(); + draftThrottles.Remove(comment); + } + + var (key, secondaryKey) = GetDraftKeys(comment); + await DraftStore.DeleteDraft(key, secondaryKey).ConfigureAwait(false); + } + + protected abstract (string key, string secondaryKey) GetDraftKeys(ICommentViewModel comment); + + void PlaceholderChanged(ICommentViewModel placeholder, CommentEditState state, string body) + { + if (state == CommentEditState.Editing) + { + if (!draftThrottles.TryGetValue(placeholder, out var throttle)) + { + var subject = new Subject(); + subject.Throttle(TimeSpan.FromSeconds(1), timerScheduler).Subscribe(UpdateDraft); + draftThrottles.Add(placeholder, subject); + throttle = subject; + } + + throttle.OnNext(placeholder); + } + else if (state != CommentEditState.Editing) + { + DeleteDraft(placeholder).Forget(); + } + } + + void UpdateDraft(ICommentViewModel comment) + { + if (comment.EditState == CommentEditState.Editing) + { + var draft = BuildDraft(comment); + var (key, secondaryKey) = GetDraftKeys(comment); + + if (draft != null) + { + DraftStore.UpdateDraft(key, secondaryKey, draft).Forget(); + } + else + { + DraftStore.DeleteDraft(key, secondaryKey).Forget(); + } + } + } } } diff --git a/src/GitHub.App/ViewModels/CommentViewModel.cs b/src/GitHub.App/ViewModels/CommentViewModel.cs index fade1742f0..259535f0c1 100644 --- a/src/GitHub.App/ViewModels/CommentViewModel.cs +++ b/src/GitHub.App/ViewModels/CommentViewModel.cs @@ -219,7 +219,7 @@ async Task DoDelete() ErrorMessage = null; IsSubmitting = true; - await Thread.DeleteComment(PullRequestId, DatabaseId).ConfigureAwait(true); + await Thread.DeleteComment(this).ConfigureAwait(true); } catch (Exception e) { @@ -264,12 +264,14 @@ async Task DoCommitEdit() if (Id == null) { - await Thread.PostComment(Body).ConfigureAwait(true); + await Thread.PostComment(this).ConfigureAwait(true); } else { - await Thread.EditComment(Id, Body).ConfigureAwait(true); + await Thread.EditComment(this).ConfigureAwait(true); } + + EditState = CommentEditState.None; } catch (Exception e) { diff --git a/src/GitHub.App/ViewModels/GitHubPane/PullRequestCreationViewModel.cs b/src/GitHub.App/ViewModels/GitHubPane/PullRequestCreationViewModel.cs index 8d2547c7a7..f72bc605fe 100644 --- a/src/GitHub.App/ViewModels/GitHubPane/PullRequestCreationViewModel.cs +++ b/src/GitHub.App/ViewModels/GitHubPane/PullRequestCreationViewModel.cs @@ -5,6 +5,7 @@ using System.Globalization; using System.Linq; using System.Reactive; +using System.Reactive.Concurrency; using System.Reactive.Disposables; using System.Reactive.Linq; using System.Threading.Tasks; @@ -14,12 +15,15 @@ using GitHub.Factories; using GitHub.Logging; using GitHub.Models; +using GitHub.Models.Drafts; +using GitHub.Primitives; using GitHub.Services; using GitHub.Validation; using Octokit; using ReactiveUI; using Serilog; using IConnection = GitHub.Models.IConnection; +using static System.FormattableString; namespace GitHub.ViewModels.GitHubPane { @@ -33,6 +37,8 @@ public class PullRequestCreationViewModel : PanePageViewModelBase, IPullRequestC readonly ObservableAsPropertyHelper isExecuting; readonly IPullRequestService service; readonly IModelServiceFactory modelServiceFactory; + readonly IMessageDraftStore draftStore; + readonly IScheduler timerScheduler; readonly CompositeDisposable disposables = new CompositeDisposable(); ILocalRepositoryModel activeLocalRepo; ObservableAsPropertyHelper githubRepository; @@ -42,14 +48,29 @@ public class PullRequestCreationViewModel : PanePageViewModelBase, IPullRequestC public PullRequestCreationViewModel( IModelServiceFactory modelServiceFactory, IPullRequestService service, - INotificationService notifications) + INotificationService notifications, + IMessageDraftStore draftStore) + : this(modelServiceFactory, service, notifications, draftStore, DefaultScheduler.Instance) + { + } + + public PullRequestCreationViewModel( + IModelServiceFactory modelServiceFactory, + IPullRequestService service, + INotificationService notifications, + IMessageDraftStore draftStore, + IScheduler timerScheduler) { Guard.ArgumentNotNull(modelServiceFactory, nameof(modelServiceFactory)); Guard.ArgumentNotNull(service, nameof(service)); Guard.ArgumentNotNull(notifications, nameof(notifications)); + Guard.ArgumentNotNull(draftStore, nameof(draftStore)); + Guard.ArgumentNotNull(timerScheduler, nameof(timerScheduler)); this.service = service; this.modelServiceFactory = modelServiceFactory; + this.draftStore = draftStore; + this.timerScheduler = timerScheduler; this.WhenAnyValue(x => x.Branches) .WhereNotNull() @@ -93,15 +114,22 @@ public PullRequestCreationViewModel( TargetBranch.Repository.CloneUrl.ToRepositoryUrl().Append("pull/" + pr.Number))); NavigateTo("/pulls?refresh=true"); Cancel.Execute(); + draftStore.DeleteDraft(GetDraftKey(), string.Empty).Forget(); + Close(); }); Cancel = ReactiveCommand.Create(() => { }); - Cancel.Subscribe(_ => Close()); + Cancel.Subscribe(_ => + { + Close(); + draftStore.DeleteDraft(GetDraftKey(), string.Empty).Forget(); + }); isExecuting = CreatePullRequest.IsExecuting.ToProperty(this, x => x.IsExecuting); this.WhenAnyValue(x => x.Initialized, x => x.GitHubRepository, x => x.IsExecuting) .Select(x => !(x.Item1 && x.Item2 != null && !x.Item3)) + .ObserveOn(RxApp.MainThreadScheduler) .Subscribe(x => IsBusy = x); } @@ -146,6 +174,39 @@ public async Task InitializeAsync(ILocalRepositoryModel repository, IConnection Initialized = true; }); + var draftKey = GetDraftKey(); + await LoadInitialState(draftKey).ConfigureAwait(true); + + this.WhenAnyValue( + x => x.PRTitle, + x => x.Description, + (t, d) => new PullRequestDraft { Title = t, Body = d }) + .Throttle(TimeSpan.FromSeconds(1), timerScheduler) + .Subscribe(x => draftStore.UpdateDraft(draftKey, string.Empty, x)); + + Initialized = true; + } + + async Task LoadInitialState(string draftKey) + { + if (activeLocalRepo.CloneUrl == null) + return; + + var draft = await draftStore.GetDraft(draftKey, string.Empty).ConfigureAwait(true); + + if (draft != null) + { + PRTitle = draft.Title; + Description = draft.Body; + } + else + { + LoadDescriptionFromCommits(); + } + } + + void LoadDescriptionFromCommits() + { SourceBranch = activeLocalRepo.CurrentBranch; var uniqueCommits = this.WhenAnyValue( @@ -176,7 +237,7 @@ public async Task InitializeAsync(ILocalRepositoryModel repository, IConnection Observable.CombineLatest( this.WhenAnyValue(x => x.SourceBranch), uniqueCommits, - service.GetPullRequestTemplate(repository).DefaultIfEmpty(string.Empty), + service.GetPullRequestTemplate(activeLocalRepo).DefaultIfEmpty(string.Empty), (compare, commits, template) => new { compare, commits, template }) .Subscribe(x => { @@ -203,8 +264,6 @@ public async Task InitializeAsync(ILocalRepositoryModel repository, IConnection PRTitle = prTitle; Description = prDescription; }); - - Initialized = true; } void SetupValidators() @@ -239,6 +298,20 @@ protected override void Dispose(bool disposing) } } + public static string GetDraftKey( + UriString cloneUri, + string branchName) + { + return Invariant($"pr|{cloneUri}|{branchName}"); + } + + protected string GetDraftKey() + { + return GetDraftKey( + activeLocalRepo.CloneUrl, + SourceBranch.Name); + } + public IRemoteRepositoryModel GitHubRepository { get { return githubRepository?.Value; } } bool IsExecuting { get { return isExecuting.Value; } } diff --git a/src/GitHub.App/ViewModels/GitHubPane/PullRequestReviewAuthoringViewModel.cs b/src/GitHub.App/ViewModels/GitHubPane/PullRequestReviewAuthoringViewModel.cs index f26b4b9ab8..fc80765fd4 100644 --- a/src/GitHub.App/ViewModels/GitHubPane/PullRequestReviewAuthoringViewModel.cs +++ b/src/GitHub.App/ViewModels/GitHubPane/PullRequestReviewAuthoringViewModel.cs @@ -3,12 +3,15 @@ using System.ComponentModel.Composition; using System.Linq; using System.Reactive; +using System.Reactive.Concurrency; using System.Reactive.Linq; using System.Threading.Tasks; using GitHub.Extensions; using GitHub.Factories; using GitHub.Logging; using GitHub.Models; +using GitHub.Models.Drafts; +using GitHub.Primitives; using GitHub.Services; using ReactiveUI; using Serilog; @@ -24,7 +27,9 @@ public class PullRequestReviewAuthoringViewModel : PanePageViewModelBase, IPullR readonly IPullRequestEditorService editorService; readonly IPullRequestSessionManager sessionManager; + readonly IMessageDraftStore draftStore; readonly IPullRequestService pullRequestService; + readonly IScheduler timerScheduler; IPullRequestSession session; IDisposable sessionSubscription; PullRequestReviewModel model; @@ -39,15 +44,31 @@ public PullRequestReviewAuthoringViewModel( IPullRequestService pullRequestService, IPullRequestEditorService editorService, IPullRequestSessionManager sessionManager, + IMessageDraftStore draftStore, IPullRequestFilesViewModel files) + : this(pullRequestService, editorService, sessionManager,draftStore, files, DefaultScheduler.Instance) + { + } + + public PullRequestReviewAuthoringViewModel( + IPullRequestService pullRequestService, + IPullRequestEditorService editorService, + IPullRequestSessionManager sessionManager, + IMessageDraftStore draftStore, + IPullRequestFilesViewModel files, + IScheduler timerScheduler) { Guard.ArgumentNotNull(editorService, nameof(editorService)); Guard.ArgumentNotNull(sessionManager, nameof(sessionManager)); + Guard.ArgumentNotNull(draftStore, nameof(draftStore)); Guard.ArgumentNotNull(files, nameof(files)); + Guard.ArgumentNotNull(timerScheduler, nameof(timerScheduler)); this.pullRequestService = pullRequestService; this.editorService = editorService; this.sessionManager = sessionManager; + this.draftStore = draftStore; + this.timerScheduler = timerScheduler; canApproveRequestChanges = this.WhenAnyValue( x => x.Model, @@ -148,8 +169,25 @@ public async Task InitializeAsync( { LocalRepository = localRepository; RemoteRepositoryOwner = owner; - session = await sessionManager.GetSession(owner, repo, pullRequestNumber); - await Load(session.PullRequest); + session = await sessionManager.GetSession(owner, repo, pullRequestNumber).ConfigureAwait(true); + await Load(session.PullRequest).ConfigureAwait(true); + + if (LocalRepository?.CloneUrl != null) + { + var key = GetDraftKey(); + + if (string.IsNullOrEmpty(Body)) + { + var draft = await draftStore.GetDraft(key, string.Empty) + .ConfigureAwait(true); + Body = draft?.Body; + } + + this.WhenAnyValue(x => x.Body) + .Throttle(TimeSpan.FromSeconds(1), timerScheduler) + .Select(x => new PullRequestReviewDraft { Body = x }) + .Subscribe(x => draftStore.UpdateDraft(key, string.Empty, x)); + } } finally { @@ -182,6 +220,20 @@ public override async Task Refresh() } } + public static string GetDraftKey( + UriString cloneUri, + int pullRequestNumber) + { + return Invariant($"pr-review|{cloneUri}|{pullRequestNumber}"); + } + + protected string GetDraftKey() + { + return GetDraftKey( + LocalRepository.CloneUrl.WithOwner(RemoteRepositoryOwner), + PullRequestModel.Number); + } + async Task Load(PullRequestDetailModel pullRequest) { try @@ -252,8 +304,9 @@ async Task DoSubmit(Octokit.PullRequestReviewEvent e) try { - await session.PostReview(Body, e); + await session.PostReview(Body, e).ConfigureAwait(true); Close(); + await draftStore.DeleteDraft(GetDraftKey(), string.Empty).ConfigureAwait(true); } catch (Exception ex) { @@ -285,6 +338,7 @@ async Task DoCancel() Close(); } + await draftStore.DeleteDraft(GetDraftKey(), string.Empty).ConfigureAwait(true); } catch (Exception ex) { diff --git a/src/GitHub.App/ViewModels/PullRequestReviewCommentThreadViewModel.cs b/src/GitHub.App/ViewModels/PullRequestReviewCommentThreadViewModel.cs index db225e9c88..db8bdd0f17 100644 --- a/src/GitHub.App/ViewModels/PullRequestReviewCommentThreadViewModel.cs +++ b/src/GitHub.App/ViewModels/PullRequestReviewCommentThreadViewModel.cs @@ -1,12 +1,17 @@ using System; using System.ComponentModel.Composition; +using System.Globalization; using System.Linq; +using System.Reactive.Linq; using System.Threading.Tasks; using GitHub.Extensions; using GitHub.Factories; using GitHub.Models; +using GitHub.Models.Drafts; +using GitHub.Primitives; using GitHub.Services; using ReactiveUI; +using static System.FormattableString; namespace GitHub.ViewModels { @@ -25,9 +30,13 @@ public class PullRequestReviewCommentThreadViewModel : CommentThreadViewModel, I /// /// Initializes a new instance of the class. /// + /// The message draft store. /// The view model factory. [ImportingConstructor] - public PullRequestReviewCommentThreadViewModel(IViewViewModelFactory factory) + public PullRequestReviewCommentThreadViewModel( + IMessageDraftStore draftStore, + IViewViewModelFactory factory) + : base(draftStore) { Guard.ArgumentNotNull(factory, nameof(factory)); @@ -75,7 +84,7 @@ public async Task InitializeAsync( { Guard.ArgumentNotNull(session, nameof(session)); - await base.InitializeAsync(session.User).ConfigureAwait(false); + await base.InitializeAsync(session.User).ConfigureAwait(true); Session = session; File = file; @@ -97,8 +106,23 @@ await vm.InitializeAsync( if (addPlaceholder) { var vm = factory.CreateViewModel(); - await vm.InitializeAsPlaceholderAsync(session, this, false).ConfigureAwait(false); - Comments.Add(vm); + + await vm.InitializeAsPlaceholderAsync( + session, + this, + review.State == PullRequestReviewState.Pending, + false).ConfigureAwait(true); + + var (key, secondaryKey) = GetDraftKeys(vm); + var draft = await DraftStore.GetDraft(key, secondaryKey).ConfigureAwait(true); + + if (draft?.Side == Side) + { + await vm.BeginEdit.Execute(); + vm.Body = draft.Body; + } + + AddPlaceholder(vm); } } @@ -121,13 +145,22 @@ public async Task InitializeNewAsync( IsNewThread = true; var vm = factory.CreateViewModel(); - await vm.InitializeAsPlaceholderAsync(session, this, isEditing).ConfigureAwait(false); - Comments.Add(vm); + await vm.InitializeAsPlaceholderAsync(session, this, session.HasPendingReview, isEditing).ConfigureAwait(false); + + var (key, secondaryKey) = GetDraftKeys(vm); + var draft = await DraftStore.GetDraft(key, secondaryKey).ConfigureAwait(true); + + if (draft?.Side == side) + { + vm.Body = draft.Body; + } + + AddPlaceholder(vm); } - public override async Task PostComment(string body) + public override async Task PostComment(ICommentViewModel comment) { - Guard.ArgumentNotNull(body, nameof(body)); + Guard.ArgumentNotNull(comment, nameof(comment)); if (IsNewThread) { @@ -145,7 +178,7 @@ public override async Task PostComment(string body) } await Session.PostReviewComment( - body, + comment.Body, File.CommitSha, File.RelativePath.Replace("\\", "/"), File.Diff, @@ -154,21 +187,54 @@ await Session.PostReviewComment( else { var replyId = Comments[0].Id; - await Session.PostReviewComment(body, replyId).ConfigureAwait(false); + await Session.PostReviewComment(comment.Body, replyId).ConfigureAwait(false); } + + await DeleteDraft(comment).ConfigureAwait(false); } - public override async Task EditComment(string id, string body) + public override async Task EditComment(ICommentViewModel comment) { - Guard.ArgumentNotNull(id, nameof(id)); - Guard.ArgumentNotNull(body, nameof(body)); + Guard.ArgumentNotNull(comment, nameof(comment)); + + await Session.EditComment(comment.Id, comment.Body).ConfigureAwait(false); + } - await Session.EditComment(id, body).ConfigureAwait(false); + public override async Task DeleteComment(ICommentViewModel comment) + { + Guard.ArgumentNotNull(comment, nameof(comment)); + + await Session.DeleteComment(comment.PullRequestId, comment.DatabaseId).ConfigureAwait(false); + } + + public static (string key, string secondaryKey) GetDraftKeys( + UriString cloneUri, + int pullRequestNumber, + string relativePath, + int lineNumber) + { + relativePath = relativePath.Replace("\\", "/"); + var key = Invariant($"pr-review-comment|{cloneUri}|{pullRequestNumber}|{relativePath}"); + return (key, lineNumber.ToString(CultureInfo.InvariantCulture)); + } + + protected override CommentDraft BuildDraft(ICommentViewModel comment) + { + return new PullRequestReviewCommentDraft + { + Body = comment.Body, + Side = Side, + UpdatedAt = DateTimeOffset.UtcNow, + }; } - public override async Task DeleteComment(int pullRequestId, int commentId) + protected override (string key, string secondaryKey) GetDraftKeys(ICommentViewModel comment) { - await Session.DeleteComment(pullRequestId, commentId).ConfigureAwait(false); + return GetDraftKeys( + Session.LocalRepository.CloneUrl.WithOwner(Session.RepositoryOwner), + Session.PullRequest.Number, + File.RelativePath, + LineNumber); } } } diff --git a/src/GitHub.App/ViewModels/PullRequestReviewCommentViewModel.cs b/src/GitHub.App/ViewModels/PullRequestReviewCommentViewModel.cs index 8a90eaaae3..dc92301ed0 100644 --- a/src/GitHub.App/ViewModels/PullRequestReviewCommentViewModel.cs +++ b/src/GitHub.App/ViewModels/PullRequestReviewCommentViewModel.cs @@ -67,6 +67,7 @@ public async Task InitializeAsync( public async Task InitializeAsPlaceholderAsync( IPullRequestSession session, ICommentThreadViewModel thread, + bool isPending, bool isEditing) { Guard.ArgumentNotNull(session, nameof(session)); @@ -77,6 +78,7 @@ await InitializeAsync( null, isEditing ? CommentEditState.Editing : CommentEditState.Placeholder).ConfigureAwait(true); this.session = session; + IsPending = isPending; } /// @@ -101,7 +103,7 @@ async Task DoStartReview() try { - await session.StartReview().ConfigureAwait(false); + await session.StartReview().ConfigureAwait(true); await CommitEdit.Execute(); } finally diff --git a/src/GitHub.App/sqlite-net/SQLite.cs b/src/GitHub.App/sqlite-net/SQLite.cs new file mode 100644 index 0000000000..6861b8c0eb --- /dev/null +++ b/src/GitHub.App/sqlite-net/SQLite.cs @@ -0,0 +1,4523 @@ +// +// Copyright (c) 2009-2018 Krueger Systems, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. +// +#if WINDOWS_PHONE && !USE_WP8_NATIVE_SQLITE +#define USE_CSHARP_SQLITE +#endif + +using System; +using System.Collections; +using System.Diagnostics; +#if !USE_SQLITEPCL_RAW +using System.Runtime.InteropServices; +#endif +using System.Collections.Generic; +using System.Reflection; +using System.Linq; +using System.Linq.Expressions; +using System.Text; +using System.Threading; + +#if USE_CSHARP_SQLITE +using Sqlite3 = Community.CsharpSqlite.Sqlite3; +using Sqlite3DatabaseHandle = Community.CsharpSqlite.Sqlite3.sqlite3; +using Sqlite3Statement = Community.CsharpSqlite.Sqlite3.Vdbe; +#elif USE_WP8_NATIVE_SQLITE +using Sqlite3 = Sqlite.Sqlite3; +using Sqlite3DatabaseHandle = Sqlite.Database; +using Sqlite3Statement = Sqlite.Statement; +#elif USE_SQLITEPCL_RAW +using Sqlite3DatabaseHandle = SQLitePCL.sqlite3; +using Sqlite3Statement = SQLitePCL.sqlite3_stmt; +using Sqlite3 = SQLitePCL.raw; +#else +using Sqlite3DatabaseHandle = System.IntPtr; +using Sqlite3Statement = System.IntPtr; +#endif + +#pragma warning disable 1591 // XML Doc Comments + +namespace SQLite +{ + public class SQLiteException : Exception + { + public SQLite3.Result Result { get; private set; } + + protected SQLiteException(SQLite3.Result r, string message) : base(message) + { + Result = r; + } + + public static SQLiteException New(SQLite3.Result r, string message) + { + return new SQLiteException(r, message); + } + } + + public class NotNullConstraintViolationException : SQLiteException + { + public IEnumerable Columns { get; protected set; } + + protected NotNullConstraintViolationException(SQLite3.Result r, string message) + : this(r, message, null, null) + { + + } + + protected NotNullConstraintViolationException(SQLite3.Result r, string message, TableMapping mapping, object obj) + : base(r, message) + { + if (mapping != null && obj != null) + { + this.Columns = from c in mapping.Columns + where c.IsNullable == false && c.GetValue(obj) == null + select c; + } + } + + public static new NotNullConstraintViolationException New(SQLite3.Result r, string message) + { + return new NotNullConstraintViolationException(r, message); + } + + public static NotNullConstraintViolationException New(SQLite3.Result r, string message, TableMapping mapping, object obj) + { + return new NotNullConstraintViolationException(r, message, mapping, obj); + } + + public static NotNullConstraintViolationException New(SQLiteException exception, TableMapping mapping, object obj) + { + return new NotNullConstraintViolationException(exception.Result, exception.Message, mapping, obj); + } + } + + [Flags] + public enum SQLiteOpenFlags + { + ReadOnly = 1, ReadWrite = 2, Create = 4, + NoMutex = 0x8000, FullMutex = 0x10000, + SharedCache = 0x20000, PrivateCache = 0x40000, + ProtectionComplete = 0x00100000, + ProtectionCompleteUnlessOpen = 0x00200000, + ProtectionCompleteUntilFirstUserAuthentication = 0x00300000, + ProtectionNone = 0x00400000 + } + + [Flags] + public enum CreateFlags + { + /// + /// Use the default creation options + /// + None = 0x000, + /// + /// Create a primary key index for a property called 'Id' (case-insensitive). + /// This avoids the need for the [PrimaryKey] attribute. + /// + ImplicitPK = 0x001, + /// + /// Create indices for properties ending in 'Id' (case-insensitive). + /// + ImplicitIndex = 0x002, + /// + /// Create a primary key for a property called 'Id' and + /// create an indices for properties ending in 'Id' (case-insensitive). + /// + AllImplicit = 0x003, + /// + /// Force the primary key property to be auto incrementing. + /// This avoids the need for the [AutoIncrement] attribute. + /// The primary key property on the class should have type int or long. + /// + AutoIncPK = 0x004, + /// + /// Create virtual table using FTS3 + /// + FullTextSearch3 = 0x100, + /// + /// Create virtual table using FTS4 + /// + FullTextSearch4 = 0x200 + } + + /// + /// An open connection to a SQLite database. + /// + [Preserve(AllMembers = true)] + public partial class SQLiteConnection : IDisposable + { + private bool _open; + private TimeSpan _busyTimeout; + readonly static Dictionary _mappings = new Dictionary(); + private System.Diagnostics.Stopwatch _sw; + private long _elapsedMilliseconds = 0; + + private int _transactionDepth = 0; + private Random _rand = new Random(); + + public Sqlite3DatabaseHandle Handle { get; private set; } + static readonly Sqlite3DatabaseHandle NullHandle = default(Sqlite3DatabaseHandle); + + /// + /// Gets the database path used by this connection. + /// + public string DatabasePath { get; private set; } + + /// + /// Gets the SQLite library version number. 3007014 would be v3.7.14 + /// + public int LibVersionNumber { get; private set; } + + /// + /// Whether Trace lines should be written that show the execution time of queries. + /// + public bool TimeExecution { get; set; } + + /// + /// Whether to writer queries to during execution. + /// + /// The tracer. + public bool Trace { get; set; } + + /// + /// The delegate responsible for writing trace lines. + /// + /// The tracer. + public Action Tracer { get; set; } + + /// + /// Whether to store DateTime properties as ticks (true) or strings (false). + /// + public bool StoreDateTimeAsTicks { get; private set; } + +#if USE_SQLITEPCL_RAW && !NO_SQLITEPCL_RAW_BATTERIES + static SQLiteConnection () + { + SQLitePCL.Batteries_V2.Init (); + } +#endif + + /// + /// Constructs a new SQLiteConnection and opens a SQLite database specified by databasePath. + /// + /// + /// Specifies the path to the database file. + /// + /// + /// Specifies whether to store DateTime properties as ticks (true) or strings (false). You + /// absolutely do want to store them as Ticks in all new projects. The value of false is + /// only here for backwards compatibility. There is a *significant* speed advantage, with no + /// down sides, when setting storeDateTimeAsTicks = true. + /// If you use DateTimeOffset properties, it will be always stored as ticks regardingless + /// the storeDateTimeAsTicks parameter. + /// + /// + /// Specifies the encryption key to use on the database. Should be a string or a byte[]. + /// + public SQLiteConnection(string databasePath, bool storeDateTimeAsTicks = true, object key = null) + : this(databasePath, SQLiteOpenFlags.ReadWrite | SQLiteOpenFlags.Create, storeDateTimeAsTicks, key: key) + { + } + + /// + /// Constructs a new SQLiteConnection and opens a SQLite database specified by databasePath. + /// + /// + /// Specifies the path to the database file. + /// + /// + /// Flags controlling how the connection should be opened. + /// + /// + /// Specifies whether to store DateTime properties as ticks (true) or strings (false). You + /// absolutely do want to store them as Ticks in all new projects. The value of false is + /// only here for backwards compatibility. There is a *significant* speed advantage, with no + /// down sides, when setting storeDateTimeAsTicks = true. + /// If you use DateTimeOffset properties, it will be always stored as ticks regardingless + /// the storeDateTimeAsTicks parameter. + /// + /// + /// Specifies the encryption key to use on the database. Should be a string or a byte[]. + /// + public SQLiteConnection(string databasePath, SQLiteOpenFlags openFlags, bool storeDateTimeAsTicks = true, object key = null) + { + if (databasePath == null) + throw new ArgumentException("Must be specified", nameof(databasePath)); + + DatabasePath = databasePath; + + LibVersionNumber = SQLite3.LibVersionNumber(); + +#if NETFX_CORE + SQLite3.SetDirectory(/*temp directory type*/2, Windows.Storage.ApplicationData.Current.TemporaryFolder.Path); +#endif + + Sqlite3DatabaseHandle handle; + +#if SILVERLIGHT || USE_CSHARP_SQLITE || USE_SQLITEPCL_RAW + var r = SQLite3.Open (databasePath, out handle, (int)openFlags, IntPtr.Zero); +#else + // open using the byte[] + // in the case where the path may include Unicode + // force open to using UTF-8 using sqlite3_open_v2 + var databasePathAsBytes = GetNullTerminatedUtf8(DatabasePath); + var r = SQLite3.Open(databasePathAsBytes, out handle, (int)openFlags, IntPtr.Zero); +#endif + + Handle = handle; + if (r != SQLite3.Result.OK) + { + throw SQLiteException.New(r, String.Format("Could not open database file: {0} ({1})", DatabasePath, r)); + } + _open = true; + + StoreDateTimeAsTicks = storeDateTimeAsTicks; + + BusyTimeout = TimeSpan.FromSeconds(0.1); + Tracer = line => Debug.WriteLine(line); + + if (key is string stringKey) + { + SetKey(stringKey); + } + else if (key is byte[] bytesKey) + { + SetKey(bytesKey); + } + else if (key != null) + { + throw new ArgumentException("Encryption keys must be strings or byte arrays", nameof(key)); + } + if (openFlags.HasFlag(SQLiteOpenFlags.ReadWrite)) + { + ExecuteScalar("PRAGMA journal_mode=WAL"); + } + } + + /// + /// Convert an input string to a quoted SQL string that can be safely used in queries. + /// + /// The quoted string. + /// The unsafe string to quote. + static string Quote(string unsafeString) + { + // TODO: Doesn't call sqlite3_mprintf("%Q", u) because we're waiting on https://github.com/ericsink/SQLitePCL.raw/issues/153 + if (unsafeString == null) return "NULL"; + var safe = unsafeString.Replace("'", "''"); + return "'" + safe + "'"; + } + + /// + /// Sets the key used to encrypt/decrypt the database with "pragma key = ...". + /// This must be the first thing you call before doing anything else with this connection + /// if your database is encrypted. + /// This only has an effect if you are using the SQLCipher nuget package. + /// + /// Ecryption key plain text that is converted to the real encryption key using PBKDF2 key derivation + void SetKey(string key) + { + if (key == null) throw new ArgumentNullException(nameof(key)); + var q = Quote(key); + Execute("pragma key = " + q); + } + + /// + /// Sets the key used to encrypt/decrypt the database. + /// This must be the first thing you call before doing anything else with this connection + /// if your database is encrypted. + /// This only has an effect if you are using the SQLCipher nuget package. + /// + /// 256-bit (32 byte) ecryption key data + void SetKey(byte[] key) + { + if (key == null) throw new ArgumentNullException(nameof(key)); + if (key.Length != 32) throw new ArgumentException("Key must be 32 bytes (256-bit)", nameof(key)); + var s = String.Join("", key.Select(x => x.ToString("X2"))); + Execute("pragma key = \"x'" + s + "'\""); + } + + /// + /// Enable or disable extension loading. + /// + public void EnableLoadExtension(bool enabled) + { + SQLite3.Result r = SQLite3.EnableLoadExtension(Handle, enabled ? 1 : 0); + if (r != SQLite3.Result.OK) + { + string msg = SQLite3.GetErrmsg(Handle); + throw SQLiteException.New(r, msg); + } + } + +#if !USE_SQLITEPCL_RAW + static byte[] GetNullTerminatedUtf8(string s) + { + var utf8Length = System.Text.Encoding.UTF8.GetByteCount(s); + var bytes = new byte[utf8Length + 1]; + utf8Length = System.Text.Encoding.UTF8.GetBytes(s, 0, s.Length, bytes, 0); + return bytes; + } +#endif + + /// + /// Sets a busy handler to sleep the specified amount of time when a table is locked. + /// The handler will sleep multiple times until a total time of has accumulated. + /// + public TimeSpan BusyTimeout + { + get { return _busyTimeout; } + set + { + _busyTimeout = value; + if (Handle != NullHandle) + { + SQLite3.BusyTimeout(Handle, (int)_busyTimeout.TotalMilliseconds); + } + } + } + + /// + /// Returns the mappings from types to tables that the connection + /// currently understands. + /// + public IEnumerable TableMappings + { + get + { + lock (_mappings) + { + return new List(_mappings.Values); + } + } + } + + /// + /// Retrieves the mapping that is automatically generated for the given type. + /// + /// + /// The type whose mapping to the database is returned. + /// + /// + /// Optional flags allowing implicit PK and indexes based on naming conventions + /// + /// + /// The mapping represents the schema of the columns of the database and contains + /// methods to set and get properties of objects. + /// + public TableMapping GetMapping(Type type, CreateFlags createFlags = CreateFlags.None) + { + TableMapping map; + var key = type.FullName; + lock (_mappings) + { + if (_mappings.TryGetValue(key, out map)) + { + if (createFlags != CreateFlags.None && createFlags != map.CreateFlags) + { + map = new TableMapping(type, createFlags); + _mappings[key] = map; + } + } + else + { + map = new TableMapping(type, createFlags); + _mappings.Add(key, map); + } + } + return map; + } + + /// + /// Retrieves the mapping that is automatically generated for the given type. + /// + /// + /// Optional flags allowing implicit PK and indexes based on naming conventions + /// + /// + /// The mapping represents the schema of the columns of the database and contains + /// methods to set and get properties of objects. + /// + public TableMapping GetMapping(CreateFlags createFlags = CreateFlags.None) + { + return GetMapping(typeof(T), createFlags); + } + + private struct IndexedColumn + { + public int Order; + public string ColumnName; + } + + private struct IndexInfo + { + public string IndexName; + public string TableName; + public bool Unique; + public List Columns; + } + + /// + /// Executes a "drop table" on the database. This is non-recoverable. + /// + public int DropTable() + { + return DropTable(GetMapping(typeof(T))); + } + + /// + /// Executes a "drop table" on the database. This is non-recoverable. + /// + /// + /// The TableMapping used to identify the table. + /// + public int DropTable(TableMapping map) + { + var query = string.Format("drop table if exists \"{0}\"", map.TableName); + return Execute(query); + } + + /// + /// Executes a "create table if not exists" on the database. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated. + /// + public CreateTableResult CreateTable(CreateFlags createFlags = CreateFlags.None) + { + return CreateTable(typeof(T), createFlags); + } + + /// + /// Executes a "create table if not exists" on the database. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// Type to reflect to a database table. + /// Optional flags allowing implicit PK and indexes based on naming conventions. + /// + /// Whether the table was created or migrated. + /// + public CreateTableResult CreateTable(Type ty, CreateFlags createFlags = CreateFlags.None) + { + var map = GetMapping(ty, createFlags); + + // Present a nice error if no columns specified + if (map.Columns.Length == 0) + { + throw new Exception(string.Format("Cannot create a table without columns (does '{0}' have public properties?)", ty.FullName)); + } + + // Check if the table exists + var result = CreateTableResult.Created; + var existingCols = GetTableInfo(map.TableName); + + // Create or migrate it + if (existingCols.Count == 0) + { + + // Facilitate virtual tables a.k.a. full-text search. + bool fts3 = (createFlags & CreateFlags.FullTextSearch3) != 0; + bool fts4 = (createFlags & CreateFlags.FullTextSearch4) != 0; + bool fts = fts3 || fts4; + var @virtual = fts ? "virtual " : string.Empty; + var @using = fts3 ? "using fts3 " : fts4 ? "using fts4 " : string.Empty; + + // Build query. + var query = "create " + @virtual + "table if not exists \"" + map.TableName + "\" " + @using + "(\n"; + var decls = map.Columns.Select(p => Orm.SqlDecl(p, StoreDateTimeAsTicks)); + var decl = string.Join(",\n", decls.ToArray()); + query += decl; + query += ")"; + if (map.WithoutRowId) + { + query += " without rowid"; + } + + Execute(query); + } + else + { + result = CreateTableResult.Migrated; + MigrateTable(map, existingCols); + } + + var indexes = new Dictionary(); + foreach (var c in map.Columns) + { + foreach (var i in c.Indices) + { + var iname = i.Name ?? map.TableName + "_" + c.Name; + IndexInfo iinfo; + if (!indexes.TryGetValue(iname, out iinfo)) + { + iinfo = new IndexInfo + { + IndexName = iname, + TableName = map.TableName, + Unique = i.Unique, + Columns = new List() + }; + indexes.Add(iname, iinfo); + } + + if (i.Unique != iinfo.Unique) + throw new Exception("All the columns in an index must have the same value for their Unique property"); + + iinfo.Columns.Add(new IndexedColumn + { + Order = i.Order, + ColumnName = c.Name + }); + } + } + + foreach (var indexName in indexes.Keys) + { + var index = indexes[indexName]; + var columns = index.Columns.OrderBy(i => i.Order).Select(i => i.ColumnName).ToArray(); + CreateIndex(indexName, index.TableName, columns, index.Unique); + } + + return result; + } + + /// + /// Executes a "create table if not exists" on the database for each type. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated for each type. + /// + public CreateTablesResult CreateTables(CreateFlags createFlags = CreateFlags.None) + where T : new() + where T2 : new() + { + return CreateTables(createFlags, typeof(T), typeof(T2)); + } + + /// + /// Executes a "create table if not exists" on the database for each type. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated for each type. + /// + public CreateTablesResult CreateTables(CreateFlags createFlags = CreateFlags.None) + where T : new() + where T2 : new() + where T3 : new() + { + return CreateTables(createFlags, typeof(T), typeof(T2), typeof(T3)); + } + + /// + /// Executes a "create table if not exists" on the database for each type. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated for each type. + /// + public CreateTablesResult CreateTables(CreateFlags createFlags = CreateFlags.None) + where T : new() + where T2 : new() + where T3 : new() + where T4 : new() + { + return CreateTables(createFlags, typeof(T), typeof(T2), typeof(T3), typeof(T4)); + } + + /// + /// Executes a "create table if not exists" on the database for each type. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated for each type. + /// + public CreateTablesResult CreateTables(CreateFlags createFlags = CreateFlags.None) + where T : new() + where T2 : new() + where T3 : new() + where T4 : new() + where T5 : new() + { + return CreateTables(createFlags, typeof(T), typeof(T2), typeof(T3), typeof(T4), typeof(T5)); + } + + /// + /// Executes a "create table if not exists" on the database for each type. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated for each type. + /// + public CreateTablesResult CreateTables(CreateFlags createFlags = CreateFlags.None, params Type[] types) + { + var result = new CreateTablesResult(); + foreach (Type type in types) + { + var aResult = CreateTable(type, createFlags); + result.Results[type] = aResult; + } + return result; + } + + /// + /// Creates an index for the specified table and columns. + /// + /// Name of the index to create + /// Name of the database table + /// An array of column names to index + /// Whether the index should be unique + public int CreateIndex(string indexName, string tableName, string[] columnNames, bool unique = false) + { + const string sqlFormat = "create {2} index if not exists \"{3}\" on \"{0}\"(\"{1}\")"; + var sql = String.Format(sqlFormat, tableName, string.Join("\", \"", columnNames), unique ? "unique" : "", indexName); + return Execute(sql); + } + + /// + /// Creates an index for the specified table and column. + /// + /// Name of the index to create + /// Name of the database table + /// Name of the column to index + /// Whether the index should be unique + public int CreateIndex(string indexName, string tableName, string columnName, bool unique = false) + { + return CreateIndex(indexName, tableName, new string[] { columnName }, unique); + } + + /// + /// Creates an index for the specified table and column. + /// + /// Name of the database table + /// Name of the column to index + /// Whether the index should be unique + public int CreateIndex(string tableName, string columnName, bool unique = false) + { + return CreateIndex(tableName + "_" + columnName, tableName, columnName, unique); + } + + /// + /// Creates an index for the specified table and columns. + /// + /// Name of the database table + /// An array of column names to index + /// Whether the index should be unique + public int CreateIndex(string tableName, string[] columnNames, bool unique = false) + { + return CreateIndex(tableName + "_" + string.Join("_", columnNames), tableName, columnNames, unique); + } + + /// + /// Creates an index for the specified object property. + /// e.g. CreateIndex<Client>(c => c.Name); + /// + /// Type to reflect to a database table. + /// Property to index + /// Whether the index should be unique + public int CreateIndex(Expression> property, bool unique = false) + { + MemberExpression mx; + if (property.Body.NodeType == ExpressionType.Convert) + { + mx = ((UnaryExpression)property.Body).Operand as MemberExpression; + } + else + { + mx = (property.Body as MemberExpression); + } + var propertyInfo = mx.Member as PropertyInfo; + if (propertyInfo == null) + { + throw new ArgumentException("The lambda expression 'property' should point to a valid Property"); + } + + var propName = propertyInfo.Name; + + var map = GetMapping(); + var colName = map.FindColumnWithPropertyName(propName).Name; + + return CreateIndex(map.TableName, colName, unique); + } + + [Preserve(AllMembers = true)] + public class ColumnInfo + { + // public int cid { get; set; } + + [Column("name")] + public string Name { get; set; } + + // [Column ("type")] + // public string ColumnType { get; set; } + + public int notnull { get; set; } + + // public string dflt_value { get; set; } + + // public int pk { get; set; } + + public override string ToString() + { + return Name; + } + } + + /// + /// Query the built-in sqlite table_info table for a specific tables columns. + /// + /// The columns contains in the table. + /// Table name. + public List GetTableInfo(string tableName) + { + var query = "pragma table_info(\"" + tableName + "\")"; + return Query(query); + } + + void MigrateTable(TableMapping map, List existingCols) + { + var toBeAdded = new List(); + + foreach (var p in map.Columns) + { + var found = false; + foreach (var c in existingCols) + { + found = (string.Compare(p.Name, c.Name, StringComparison.OrdinalIgnoreCase) == 0); + if (found) + break; + } + if (!found) + { + toBeAdded.Add(p); + } + } + + foreach (var p in toBeAdded) + { + var addCol = "alter table \"" + map.TableName + "\" add column " + Orm.SqlDecl(p, StoreDateTimeAsTicks); + Execute(addCol); + } + } + + /// + /// Creates a new SQLiteCommand. Can be overridden to provide a sub-class. + /// + /// + protected virtual SQLiteCommand NewCommand() + { + return new SQLiteCommand(this); + } + + /// + /// Creates a new SQLiteCommand given the command text with arguments. Place a '?' + /// in the command text for each of the arguments. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the command text. + /// + /// + /// A + /// + public SQLiteCommand CreateCommand(string cmdText, params object[] ps) + { + if (!_open) + throw SQLiteException.New(SQLite3.Result.Error, "Cannot create commands from unopened database"); + + var cmd = NewCommand(); + cmd.CommandText = cmdText; + foreach (var o in ps) + { + cmd.Bind(o); + } + return cmd; + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// Use this method instead of Query when you don't expect rows back. Such cases include + /// INSERTs, UPDATEs, and DELETEs. + /// You can set the Trace or TimeExecution properties of the connection + /// to profile execution. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// The number of rows modified in the database as a result of this execution. + /// + public int Execute(string query, params object[] args) + { + var cmd = CreateCommand(query, args); + + if (TimeExecution) + { + if (_sw == null) + { + _sw = new Stopwatch(); + } + _sw.Reset(); + _sw.Start(); + } + + var r = cmd.ExecuteNonQuery(); + + if (TimeExecution) + { + _sw.Stop(); + _elapsedMilliseconds += _sw.ElapsedMilliseconds; + Tracer?.Invoke(string.Format("Finished in {0} ms ({1:0.0} s total)", _sw.ElapsedMilliseconds, _elapsedMilliseconds / 1000.0)); + } + + return r; + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// Use this method when return primitive values. + /// You can set the Trace or TimeExecution properties of the connection + /// to profile execution. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// The number of rows modified in the database as a result of this execution. + /// + public T ExecuteScalar(string query, params object[] args) + { + var cmd = CreateCommand(query, args); + + if (TimeExecution) + { + if (_sw == null) + { + _sw = new Stopwatch(); + } + _sw.Reset(); + _sw.Start(); + } + + var r = cmd.ExecuteScalar(); + + if (TimeExecution) + { + _sw.Stop(); + _elapsedMilliseconds += _sw.ElapsedMilliseconds; + Tracer?.Invoke(string.Format("Finished in {0} ms ({1:0.0} s total)", _sw.ElapsedMilliseconds, _elapsedMilliseconds / 1000.0)); + } + + return r; + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// It returns each row of the result using the mapping automatically generated for + /// the given type. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for each row returned by the query. + /// + public List Query(string query, params object[] args) where T : new() + { + var cmd = CreateCommand(query, args); + return cmd.ExecuteQuery(); + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// It returns each row of the result using the mapping automatically generated for + /// the given type. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for each row returned by the query. + /// The enumerator (retrieved by calling GetEnumerator() on the result of this method) + /// will call sqlite3_step on each call to MoveNext, so the database + /// connection must remain open for the lifetime of the enumerator. + /// + public IEnumerable DeferredQuery(string query, params object[] args) where T : new() + { + var cmd = CreateCommand(query, args); + return cmd.ExecuteDeferredQuery(); + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// It returns each row of the result using the specified mapping. This function is + /// only used by libraries in order to query the database via introspection. It is + /// normally not used. + /// + /// + /// A to use to convert the resulting rows + /// into objects. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for each row returned by the query. + /// + public List Query(TableMapping map, string query, params object[] args) + { + var cmd = CreateCommand(query, args); + return cmd.ExecuteQuery(map); + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// It returns each row of the result using the specified mapping. This function is + /// only used by libraries in order to query the database via introspection. It is + /// normally not used. + /// + /// + /// A to use to convert the resulting rows + /// into objects. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for each row returned by the query. + /// The enumerator (retrieved by calling GetEnumerator() on the result of this method) + /// will call sqlite3_step on each call to MoveNext, so the database + /// connection must remain open for the lifetime of the enumerator. + /// + public IEnumerable DeferredQuery(TableMapping map, string query, params object[] args) + { + var cmd = CreateCommand(query, args); + return cmd.ExecuteDeferredQuery(map); + } + + /// + /// Returns a queryable interface to the table represented by the given type. + /// + /// + /// A queryable object that is able to translate Where, OrderBy, and Take + /// queries into native SQL. + /// + public TableQuery Table() where T : new() + { + return new TableQuery(this); + } + + /// + /// Attempts to retrieve an object with the given primary key from the table + /// associated with the specified type. Use of this method requires that + /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). + /// + /// + /// The primary key. + /// + /// + /// The object with the given primary key. Throws a not found exception + /// if the object is not found. + /// + public T Get(object pk) where T : new() + { + var map = GetMapping(typeof(T)); + return Query(map.GetByPrimaryKeySql, pk).First(); + } + + /// + /// Attempts to retrieve an object with the given primary key from the table + /// associated with the specified type. Use of this method requires that + /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). + /// + /// + /// The primary key. + /// + /// + /// The TableMapping used to identify the table. + /// + /// + /// The object with the given primary key. Throws a not found exception + /// if the object is not found. + /// + public object Get(object pk, TableMapping map) + { + return Query(map, map.GetByPrimaryKeySql, pk).First(); + } + + /// + /// Attempts to retrieve the first object that matches the predicate from the table + /// associated with the specified type. + /// + /// + /// A predicate for which object to find. + /// + /// + /// The object that matches the given predicate. Throws a not found exception + /// if the object is not found. + /// + public T Get(Expression> predicate) where T : new() + { + return Table().Where(predicate).First(); + } + + /// + /// Attempts to retrieve an object with the given primary key from the table + /// associated with the specified type. Use of this method requires that + /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). + /// + /// + /// The primary key. + /// + /// + /// The object with the given primary key or null + /// if the object is not found. + /// + public T Find(object pk) where T : new() + { + var map = GetMapping(typeof(T)); + return Query(map.GetByPrimaryKeySql, pk).FirstOrDefault(); + } + + /// + /// Attempts to retrieve an object with the given primary key from the table + /// associated with the specified type. Use of this method requires that + /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). + /// + /// + /// The primary key. + /// + /// + /// The TableMapping used to identify the table. + /// + /// + /// The object with the given primary key or null + /// if the object is not found. + /// + public object Find(object pk, TableMapping map) + { + return Query(map, map.GetByPrimaryKeySql, pk).FirstOrDefault(); + } + + /// + /// Attempts to retrieve the first object that matches the predicate from the table + /// associated with the specified type. + /// + /// + /// A predicate for which object to find. + /// + /// + /// The object that matches the given predicate or null + /// if the object is not found. + /// + public T Find(Expression> predicate) where T : new() + { + return Table().Where(predicate).FirstOrDefault(); + } + + /// + /// Attempts to retrieve the first object that matches the query from the table + /// associated with the specified type. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// The object that matches the given predicate or null + /// if the object is not found. + /// + public T FindWithQuery(string query, params object[] args) where T : new() + { + return Query(query, args).FirstOrDefault(); + } + + /// + /// Attempts to retrieve the first object that matches the query from the table + /// associated with the specified type. + /// + /// + /// The TableMapping used to identify the table. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// The object that matches the given predicate or null + /// if the object is not found. + /// + public object FindWithQuery(TableMapping map, string query, params object[] args) + { + return Query(map, query, args).FirstOrDefault(); + } + + /// + /// Whether has been called and the database is waiting for a . + /// + public bool IsInTransaction + { + get { return _transactionDepth > 0; } + } + + /// + /// Begins a new transaction. Call to end the transaction. + /// + /// Throws if a transaction has already begun. + public void BeginTransaction() + { + // The BEGIN command only works if the transaction stack is empty, + // or in other words if there are no pending transactions. + // If the transaction stack is not empty when the BEGIN command is invoked, + // then the command fails with an error. + // Rather than crash with an error, we will just ignore calls to BeginTransaction + // that would result in an error. + if (Interlocked.CompareExchange(ref _transactionDepth, 1, 0) == 0) + { + try + { + Execute("begin transaction"); + } + catch (Exception ex) + { + var sqlExp = ex as SQLiteException; + if (sqlExp != null) + { + // It is recommended that applications respond to the errors listed below + // by explicitly issuing a ROLLBACK command. + // TODO: This rollback failsafe should be localized to all throw sites. + switch (sqlExp.Result) + { + case SQLite3.Result.IOError: + case SQLite3.Result.Full: + case SQLite3.Result.Busy: + case SQLite3.Result.NoMem: + case SQLite3.Result.Interrupt: + RollbackTo(null, true); + break; + } + } + else + { + // Call decrement and not VolatileWrite in case we've already + // created a transaction point in SaveTransactionPoint since the catch. + Interlocked.Decrement(ref _transactionDepth); + } + + throw; + } + } + else + { + // Calling BeginTransaction on an already open transaction is invalid + throw new InvalidOperationException("Cannot begin a transaction while already in a transaction."); + } + } + + /// + /// Creates a savepoint in the database at the current point in the transaction timeline. + /// Begins a new transaction if one is not in progress. + /// + /// Call to undo transactions since the returned savepoint. + /// Call to commit transactions after the savepoint returned here. + /// Call to end the transaction, committing all changes. + /// + /// A string naming the savepoint. + public string SaveTransactionPoint() + { + int depth = Interlocked.Increment(ref _transactionDepth) - 1; + string retVal = "S" + _rand.Next(short.MaxValue) + "D" + depth; + + try + { + Execute("savepoint " + retVal); + } + catch (Exception ex) + { + var sqlExp = ex as SQLiteException; + if (sqlExp != null) + { + // It is recommended that applications respond to the errors listed below + // by explicitly issuing a ROLLBACK command. + // TODO: This rollback failsafe should be localized to all throw sites. + switch (sqlExp.Result) + { + case SQLite3.Result.IOError: + case SQLite3.Result.Full: + case SQLite3.Result.Busy: + case SQLite3.Result.NoMem: + case SQLite3.Result.Interrupt: + RollbackTo(null, true); + break; + } + } + else + { + Interlocked.Decrement(ref _transactionDepth); + } + + throw; + } + + return retVal; + } + + /// + /// Rolls back the transaction that was begun by or . + /// + public void Rollback() + { + RollbackTo(null, false); + } + + /// + /// Rolls back the savepoint created by or SaveTransactionPoint. + /// + /// The name of the savepoint to roll back to, as returned by . If savepoint is null or empty, this method is equivalent to a call to + public void RollbackTo(string savepoint) + { + RollbackTo(savepoint, false); + } + + /// + /// Rolls back the transaction that was begun by . + /// + /// The name of the savepoint to roll back to, as returned by . If savepoint is null or empty, this method is equivalent to a call to + /// true to avoid throwing exceptions, false otherwise + void RollbackTo(string savepoint, bool noThrow) + { + // Rolling back without a TO clause rolls backs all transactions + // and leaves the transaction stack empty. + try + { + if (String.IsNullOrEmpty(savepoint)) + { + if (Interlocked.Exchange(ref _transactionDepth, 0) > 0) + { + Execute("rollback"); + } + } + else + { + DoSavePointExecute(savepoint, "rollback to "); + } + } + catch (SQLiteException) + { + if (!noThrow) + throw; + + } + // No need to rollback if there are no transactions open. + } + + /// + /// Releases a savepoint returned from . Releasing a savepoint + /// makes changes since that savepoint permanent if the savepoint began the transaction, + /// or otherwise the changes are permanent pending a call to . + /// + /// The RELEASE command is like a COMMIT for a SAVEPOINT. + /// + /// The name of the savepoint to release. The string should be the result of a call to + public void Release(string savepoint) + { + try + { + DoSavePointExecute(savepoint, "release "); + } + catch (SQLiteException ex) + { + if (ex.Result == SQLite3.Result.Busy) + { + // Force a rollback since most people don't know this function can fail + // Don't call Rollback() since the _transactionDepth is 0 and it won't try + // Calling rollback makes our _transactionDepth variable correct. + // Writes to the database only happen at depth=0, so this failure will only happen then. + try + { + Execute("rollback"); + } + catch + { + // rollback can fail in all sorts of wonderful version-dependent ways. Let's just hope for the best + } + } + throw; + } + } + + void DoSavePointExecute(string savepoint, string cmd) + { + // Validate the savepoint + int firstLen = savepoint.IndexOf('D'); + if (firstLen >= 2 && savepoint.Length > firstLen + 1) + { + int depth; + if (Int32.TryParse(savepoint.Substring(firstLen + 1), out depth)) + { + // TODO: Mild race here, but inescapable without locking almost everywhere. + if (0 <= depth && depth < _transactionDepth) + { +#if NETFX_CORE || USE_SQLITEPCL_RAW || NETCORE + Volatile.Write (ref _transactionDepth, depth); +#elif SILVERLIGHT + _transactionDepth = depth; +#else + Thread.VolatileWrite(ref _transactionDepth, depth); +#endif + Execute(cmd + savepoint); + return; + } + } + } + + throw new ArgumentException("savePoint is not valid, and should be the result of a call to SaveTransactionPoint.", "savePoint"); + } + + /// + /// Commits the transaction that was begun by . + /// + public void Commit() + { + if (Interlocked.Exchange(ref _transactionDepth, 0) != 0) + { + try + { + Execute("commit"); + } + catch + { + // Force a rollback since most people don't know this function can fail + // Don't call Rollback() since the _transactionDepth is 0 and it won't try + // Calling rollback makes our _transactionDepth variable correct. + try + { + Execute("rollback"); + } + catch + { + // rollback can fail in all sorts of wonderful version-dependent ways. Let's just hope for the best + } + throw; + } + } + // Do nothing on a commit with no open transaction + } + + /// + /// Executes within a (possibly nested) transaction by wrapping it in a SAVEPOINT. If an + /// exception occurs the whole transaction is rolled back, not just the current savepoint. The exception + /// is rethrown. + /// + /// + /// The to perform within a transaction. can contain any number + /// of operations on the connection but should never call or + /// . + /// + public void RunInTransaction(Action action) + { + try + { + var savePoint = SaveTransactionPoint(); + action(); + Release(savePoint); + } + catch (Exception) + { + Rollback(); + throw; + } + } + + /// + /// Inserts all specified objects. + /// + /// + /// An of the objects to insert. + /// + /// A boolean indicating if the inserts should be wrapped in a transaction. + /// + /// + /// The number of rows added to the table. + /// + public int InsertAll(System.Collections.IEnumerable objects, bool runInTransaction = true) + { + var c = 0; + if (runInTransaction) + { + RunInTransaction(() => { + foreach (var r in objects) + { + c += Insert(r); + } + }); + } + else + { + foreach (var r in objects) + { + c += Insert(r); + } + } + return c; + } + + /// + /// Inserts all specified objects. + /// + /// + /// An of the objects to insert. + /// + /// + /// Literal SQL code that gets placed into the command. INSERT {extra} INTO ... + /// + /// + /// A boolean indicating if the inserts should be wrapped in a transaction. + /// + /// + /// The number of rows added to the table. + /// + public int InsertAll(System.Collections.IEnumerable objects, string extra, bool runInTransaction = true) + { + var c = 0; + if (runInTransaction) + { + RunInTransaction(() => { + foreach (var r in objects) + { + c += Insert(r, extra); + } + }); + } + else + { + foreach (var r in objects) + { + c += Insert(r); + } + } + return c; + } + + /// + /// Inserts all specified objects. + /// + /// + /// An of the objects to insert. + /// + /// + /// The type of object to insert. + /// + /// + /// A boolean indicating if the inserts should be wrapped in a transaction. + /// + /// + /// The number of rows added to the table. + /// + public int InsertAll(System.Collections.IEnumerable objects, Type objType, bool runInTransaction = true) + { + var c = 0; + if (runInTransaction) + { + RunInTransaction(() => { + foreach (var r in objects) + { + c += Insert(r, objType); + } + }); + } + else + { + foreach (var r in objects) + { + c += Insert(r, objType); + } + } + return c; + } + + /// + /// Inserts the given object (and updates its + /// auto incremented primary key if it has one). + /// The return value is the number of rows added to the table. + /// + /// + /// The object to insert. + /// + /// + /// The number of rows added to the table. + /// + public int Insert(object obj) + { + if (obj == null) + { + return 0; + } + return Insert(obj, "", Orm.GetType(obj)); + } + + /// + /// Inserts the given object (and updates its + /// auto incremented primary key if it has one). + /// The return value is the number of rows added to the table. + /// If a UNIQUE constraint violation occurs with + /// some pre-existing object, this function deletes + /// the old object. + /// + /// + /// The object to insert. + /// + /// + /// The number of rows modified. + /// + public int InsertOrReplace(object obj) + { + if (obj == null) + { + return 0; + } + return Insert(obj, "OR REPLACE", Orm.GetType(obj)); + } + + /// + /// Inserts the given object (and updates its + /// auto incremented primary key if it has one). + /// The return value is the number of rows added to the table. + /// + /// + /// The object to insert. + /// + /// + /// The type of object to insert. + /// + /// + /// The number of rows added to the table. + /// + public int Insert(object obj, Type objType) + { + return Insert(obj, "", objType); + } + + /// + /// Inserts the given object (and updates its + /// auto incremented primary key if it has one). + /// The return value is the number of rows added to the table. + /// If a UNIQUE constraint violation occurs with + /// some pre-existing object, this function deletes + /// the old object. + /// + /// + /// The object to insert. + /// + /// + /// The type of object to insert. + /// + /// + /// The number of rows modified. + /// + public int InsertOrReplace(object obj, Type objType) + { + return Insert(obj, "OR REPLACE", objType); + } + + /// + /// Inserts the given object (and updates its + /// auto incremented primary key if it has one). + /// The return value is the number of rows added to the table. + /// + /// + /// The object to insert. + /// + /// + /// Literal SQL code that gets placed into the command. INSERT {extra} INTO ... + /// + /// + /// The number of rows added to the table. + /// + public int Insert(object obj, string extra) + { + if (obj == null) + { + return 0; + } + return Insert(obj, extra, Orm.GetType(obj)); + } + + /// + /// Inserts the given object (and updates its + /// auto incremented primary key if it has one). + /// The return value is the number of rows added to the table. + /// + /// + /// The object to insert. + /// + /// + /// Literal SQL code that gets placed into the command. INSERT {extra} INTO ... + /// + /// + /// The type of object to insert. + /// + /// + /// The number of rows added to the table. + /// + public int Insert(object obj, string extra, Type objType) + { + if (obj == null || objType == null) + { + return 0; + } + + var map = GetMapping(objType); + + if (map.PK != null && map.PK.IsAutoGuid) + { + if (map.PK.GetValue(obj).Equals(Guid.Empty)) + { + map.PK.SetValue(obj, Guid.NewGuid()); + } + } + + var replacing = string.Compare(extra, "OR REPLACE", StringComparison.OrdinalIgnoreCase) == 0; + + var cols = replacing ? map.InsertOrReplaceColumns : map.InsertColumns; + var vals = new object[cols.Length]; + for (var i = 0; i < vals.Length; i++) + { + vals[i] = cols[i].GetValue(obj); + } + + var insertCmd = GetInsertCommand(map, extra); + int count; + + lock (insertCmd) + { + // We lock here to protect the prepared statement returned via GetInsertCommand. + // A SQLite prepared statement can be bound for only one operation at a time. + try + { + count = insertCmd.ExecuteNonQuery(vals); + } + catch (SQLiteException ex) + { + if (SQLite3.ExtendedErrCode(this.Handle) == SQLite3.ExtendedResult.ConstraintNotNull) + { + throw NotNullConstraintViolationException.New(ex.Result, ex.Message, map, obj); + } + throw; + } + + if (map.HasAutoIncPK) + { + var id = SQLite3.LastInsertRowid(Handle); + map.SetAutoIncPK(obj, id); + } + } + if (count > 0) + OnTableChanged(map, NotifyTableChangedAction.Insert); + + return count; + } + + readonly Dictionary, PreparedSqlLiteInsertCommand> _insertCommandMap = new Dictionary, PreparedSqlLiteInsertCommand>(); + + PreparedSqlLiteInsertCommand GetInsertCommand(TableMapping map, string extra) + { + PreparedSqlLiteInsertCommand prepCmd; + + var key = Tuple.Create(map.MappedType.FullName, extra); + + lock (_insertCommandMap) + { + _insertCommandMap.TryGetValue(key, out prepCmd); + } + + if (prepCmd == null) + { + prepCmd = CreateInsertCommand(map, extra); + var added = false; + lock (_insertCommandMap) + { + if (!_insertCommandMap.ContainsKey(key)) + { + _insertCommandMap.Add(key, prepCmd); + added = true; + } + } + if (!added) + { + prepCmd.Dispose(); + } + } + + return prepCmd; + } + + PreparedSqlLiteInsertCommand CreateInsertCommand(TableMapping map, string extra) + { + var cols = map.InsertColumns; + string insertSql; + if (cols.Length == 0 && map.Columns.Length == 1 && map.Columns[0].IsAutoInc) + { + insertSql = string.Format("insert {1} into \"{0}\" default values", map.TableName, extra); + } + else + { + var replacing = string.Compare(extra, "OR REPLACE", StringComparison.OrdinalIgnoreCase) == 0; + + if (replacing) + { + cols = map.InsertOrReplaceColumns; + } + + insertSql = string.Format("insert {3} into \"{0}\"({1}) values ({2})", map.TableName, + string.Join(",", (from c in cols + select "\"" + c.Name + "\"").ToArray()), + string.Join(",", (from c in cols + select "?").ToArray()), extra); + + } + + var insertCommand = new PreparedSqlLiteInsertCommand(this, insertSql); + return insertCommand; + } + + /// + /// Updates all of the columns of a table using the specified object + /// except for its primary key. + /// The object is required to have a primary key. + /// + /// + /// The object to update. It must have a primary key designated using the PrimaryKeyAttribute. + /// + /// + /// The number of rows updated. + /// + public int Update(object obj) + { + if (obj == null) + { + return 0; + } + return Update(obj, Orm.GetType(obj)); + } + + /// + /// Updates all of the columns of a table using the specified object + /// except for its primary key. + /// The object is required to have a primary key. + /// + /// + /// The object to update. It must have a primary key designated using the PrimaryKeyAttribute. + /// + /// + /// The type of object to insert. + /// + /// + /// The number of rows updated. + /// + public int Update(object obj, Type objType) + { + int rowsAffected = 0; + if (obj == null || objType == null) + { + return 0; + } + + var map = GetMapping(objType); + + var pk = map.PK; + + if (pk == null) + { + throw new NotSupportedException("Cannot update " + map.TableName + ": it has no PK"); + } + + var cols = from p in map.Columns + where p != pk + select p; + var vals = from c in cols + select c.GetValue(obj); + var ps = new List(vals); + if (ps.Count == 0) + { + // There is a PK but no accompanying data, + // so reset the PK to make the UPDATE work. + cols = map.Columns; + vals = from c in cols + select c.GetValue(obj); + ps = new List(vals); + } + ps.Add(pk.GetValue(obj)); + var q = string.Format("update \"{0}\" set {1} where {2} = ? ", map.TableName, string.Join(",", (from c in cols + select "\"" + c.Name + "\" = ? ").ToArray()), pk.Name); + + try + { + rowsAffected = Execute(q, ps.ToArray()); + } + catch (SQLiteException ex) + { + + if (ex.Result == SQLite3.Result.Constraint && SQLite3.ExtendedErrCode(this.Handle) == SQLite3.ExtendedResult.ConstraintNotNull) + { + throw NotNullConstraintViolationException.New(ex, map, obj); + } + + throw ex; + } + + if (rowsAffected > 0) + OnTableChanged(map, NotifyTableChangedAction.Update); + + return rowsAffected; + } + + /// + /// Updates all specified objects. + /// + /// + /// An of the objects to insert. + /// + /// + /// A boolean indicating if the inserts should be wrapped in a transaction + /// + /// + /// The number of rows modified. + /// + public int UpdateAll(System.Collections.IEnumerable objects, bool runInTransaction = true) + { + var c = 0; + if (runInTransaction) + { + RunInTransaction(() => { + foreach (var r in objects) + { + c += Update(r); + } + }); + } + else + { + foreach (var r in objects) + { + c += Update(r); + } + } + return c; + } + + /// + /// Deletes the given object from the database using its primary key. + /// + /// + /// The object to delete. It must have a primary key designated using the PrimaryKeyAttribute. + /// + /// + /// The number of rows deleted. + /// + public int Delete(object objectToDelete) + { + var map = GetMapping(Orm.GetType(objectToDelete)); + var pk = map.PK; + if (pk == null) + { + throw new NotSupportedException("Cannot delete " + map.TableName + ": it has no PK"); + } + var q = string.Format("delete from \"{0}\" where \"{1}\" = ?", map.TableName, pk.Name); + var count = Execute(q, pk.GetValue(objectToDelete)); + if (count > 0) + OnTableChanged(map, NotifyTableChangedAction.Delete); + return count; + } + + /// + /// Deletes the object with the specified primary key. + /// + /// + /// The primary key of the object to delete. + /// + /// + /// The number of objects deleted. + /// + /// + /// The type of object. + /// + public int Delete(object primaryKey) + { + return Delete(primaryKey, GetMapping(typeof(T))); + } + + /// + /// Deletes the object with the specified primary key. + /// + /// + /// The primary key of the object to delete. + /// + /// + /// The TableMapping used to identify the table. + /// + /// + /// The number of objects deleted. + /// + public int Delete(object primaryKey, TableMapping map) + { + var pk = map.PK; + if (pk == null) + { + throw new NotSupportedException("Cannot delete " + map.TableName + ": it has no PK"); + } + var q = string.Format("delete from \"{0}\" where \"{1}\" = ?", map.TableName, pk.Name); + var count = Execute(q, primaryKey); + if (count > 0) + OnTableChanged(map, NotifyTableChangedAction.Delete); + return count; + } + + /// + /// Deletes all the objects from the specified table. + /// WARNING WARNING: Let me repeat. It deletes ALL the objects from the + /// specified table. Do you really want to do that? + /// + /// + /// The number of objects deleted. + /// + /// + /// The type of objects to delete. + /// + public int DeleteAll() + { + var map = GetMapping(typeof(T)); + return DeleteAll(map); + } + + /// + /// Deletes all the objects from the specified table. + /// WARNING WARNING: Let me repeat. It deletes ALL the objects from the + /// specified table. Do you really want to do that? + /// + /// + /// The TableMapping used to identify the table. + /// + /// + /// The number of objects deleted. + /// + public int DeleteAll(TableMapping map) + { + var query = string.Format("delete from \"{0}\"", map.TableName); + var count = Execute(query); + if (count > 0) + OnTableChanged(map, NotifyTableChangedAction.Delete); + return count; + } + + ~SQLiteConnection() + { + Dispose(false); + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + public void Close() + { + Dispose(true); + } + + protected virtual void Dispose(bool disposing) + { + var useClose2 = LibVersionNumber >= 3007014; + + if (_open && Handle != NullHandle) + { + try + { + if (disposing) + { + lock (_insertCommandMap) + { + foreach (var sqlInsertCommand in _insertCommandMap.Values) + { + sqlInsertCommand.Dispose(); + } + _insertCommandMap.Clear(); + } + + var r = useClose2 ? SQLite3.Close2(Handle) : SQLite3.Close(Handle); + if (r != SQLite3.Result.OK) + { + string msg = SQLite3.GetErrmsg(Handle); + throw SQLiteException.New(r, msg); + } + } + else + { + var r = useClose2 ? SQLite3.Close2(Handle) : SQLite3.Close(Handle); + } + } + finally + { + Handle = NullHandle; + _open = false; + } + } + } + + void OnTableChanged(TableMapping table, NotifyTableChangedAction action) + { + var ev = TableChanged; + if (ev != null) + ev(this, new NotifyTableChangedEventArgs(table, action)); + } + + public event EventHandler TableChanged; + } + + public class NotifyTableChangedEventArgs : EventArgs + { + public TableMapping Table { get; private set; } + public NotifyTableChangedAction Action { get; private set; } + + public NotifyTableChangedEventArgs(TableMapping table, NotifyTableChangedAction action) + { + Table = table; + Action = action; + } + } + + public enum NotifyTableChangedAction + { + Insert, + Update, + Delete, + } + + /// + /// Represents a parsed connection string. + /// + public class SQLiteConnectionString + { + public string ConnectionString { get; private set; } + public string DatabasePath { get; private set; } + public bool StoreDateTimeAsTicks { get; private set; } + public object Key { get; private set; } + +#if NETFX_CORE + static readonly string MetroStyleDataPath = Windows.Storage.ApplicationData.Current.LocalFolder.Path; + + public static readonly string[] InMemoryDbPaths = new[] + { + ":memory:", + "file::memory:" + }; + + public static bool IsInMemoryPath(string databasePath) + { + return InMemoryDbPaths.Any(i => i.Equals(databasePath, StringComparison.OrdinalIgnoreCase)); + } + +#endif + + public SQLiteConnectionString(string databasePath, bool storeDateTimeAsTicks, object key) + { + ConnectionString = databasePath; + StoreDateTimeAsTicks = storeDateTimeAsTicks; + Key = key; + +#if NETFX_CORE + DatabasePath = IsInMemoryPath(databasePath) + ? databasePath + : System.IO.Path.Combine(MetroStyleDataPath, databasePath); + +#else + DatabasePath = databasePath; +#endif + } + } + + [AttributeUsage(AttributeTargets.Class)] + public class TableAttribute : Attribute + { + public string Name { get; set; } + + /// + /// Flag whether to create the table without rowid (see https://sqlite.org/withoutrowid.html) + /// + /// The default is false so that sqlite adds an implicit rowid to every table created. + /// + public bool WithoutRowId { get; set; } + + public TableAttribute(string name) + { + Name = name; + } + } + + [AttributeUsage(AttributeTargets.Property)] + public class ColumnAttribute : Attribute + { + public string Name { get; set; } + + public ColumnAttribute(string name) + { + Name = name; + } + } + + [AttributeUsage(AttributeTargets.Property)] + public class PrimaryKeyAttribute : Attribute + { + } + + [AttributeUsage(AttributeTargets.Property)] + public class AutoIncrementAttribute : Attribute + { + } + + [AttributeUsage(AttributeTargets.Property)] + public class IndexedAttribute : Attribute + { + public string Name { get; set; } + public int Order { get; set; } + public virtual bool Unique { get; set; } + + public IndexedAttribute() + { + } + + public IndexedAttribute(string name, int order) + { + Name = name; + Order = order; + } + } + + [AttributeUsage(AttributeTargets.Property)] + public class IgnoreAttribute : Attribute + { + } + + [AttributeUsage(AttributeTargets.Property)] + public class UniqueAttribute : IndexedAttribute + { + public override bool Unique + { + get { return true; } + set { /* throw? */ } + } + } + + [AttributeUsage(AttributeTargets.Property)] + public class MaxLengthAttribute : Attribute + { + public int Value { get; private set; } + + public MaxLengthAttribute(int length) + { + Value = length; + } + } + + public sealed class PreserveAttribute : System.Attribute + { + public bool AllMembers; + public bool Conditional; + } + + /// + /// Select the collating sequence to use on a column. + /// "BINARY", "NOCASE", and "RTRIM" are supported. + /// "BINARY" is the default. + /// + [AttributeUsage(AttributeTargets.Property)] + public class CollationAttribute : Attribute + { + public string Value { get; private set; } + + public CollationAttribute(string collation) + { + Value = collation; + } + } + + [AttributeUsage(AttributeTargets.Property)] + public class NotNullAttribute : Attribute + { + } + + [AttributeUsage(AttributeTargets.Enum)] + public class StoreAsTextAttribute : Attribute + { + } + + public class TableMapping + { + public Type MappedType { get; private set; } + + public string TableName { get; private set; } + + public bool WithoutRowId { get; private set; } + + public Column[] Columns { get; private set; } + + public Column PK { get; private set; } + + public string GetByPrimaryKeySql { get; private set; } + + public CreateFlags CreateFlags { get; private set; } + + readonly Column _autoPk; + readonly Column[] _insertColumns; + readonly Column[] _insertOrReplaceColumns; + + public TableMapping(Type type, CreateFlags createFlags = CreateFlags.None) + { + MappedType = type; + CreateFlags = createFlags; + + var typeInfo = type.GetTypeInfo(); + var tableAttr = + typeInfo.CustomAttributes + .Where(x => x.AttributeType == typeof(TableAttribute)) + .Select(x => (TableAttribute)Orm.InflateAttribute(x)) + .FirstOrDefault(); + + TableName = (tableAttr != null && !string.IsNullOrEmpty(tableAttr.Name)) ? tableAttr.Name : MappedType.Name; + WithoutRowId = tableAttr != null ? tableAttr.WithoutRowId : false; + + var props = new List(); + var baseType = type; + var propNames = new HashSet(); + while (baseType != typeof(object)) + { + var ti = baseType.GetTypeInfo(); + var newProps = ( + from p in ti.DeclaredProperties + where + !propNames.Contains(p.Name) && + p.CanRead && p.CanWrite && + (p.GetMethod != null) && (p.SetMethod != null) && + (p.GetMethod.IsPublic && p.SetMethod.IsPublic) && + (!p.GetMethod.IsStatic) && (!p.SetMethod.IsStatic) + select p).ToList(); + foreach (var p in newProps) + { + propNames.Add(p.Name); + } + props.AddRange(newProps); + baseType = ti.BaseType; + } + + var cols = new List(); + foreach (var p in props) + { + var ignore = p.IsDefined(typeof(IgnoreAttribute), true); + if (!ignore) + { + cols.Add(new Column(p, createFlags)); + } + } + Columns = cols.ToArray(); + foreach (var c in Columns) + { + if (c.IsAutoInc && c.IsPK) + { + _autoPk = c; + } + if (c.IsPK) + { + PK = c; + } + } + + HasAutoIncPK = _autoPk != null; + + if (PK != null) + { + GetByPrimaryKeySql = string.Format("select * from \"{0}\" where \"{1}\" = ?", TableName, PK.Name); + } + else + { + // People should not be calling Get/Find without a PK + GetByPrimaryKeySql = string.Format("select * from \"{0}\" limit 1", TableName); + } + + _insertColumns = Columns.Where(c => !c.IsAutoInc).ToArray(); + _insertOrReplaceColumns = Columns.ToArray(); + } + + public bool HasAutoIncPK { get; private set; } + + public void SetAutoIncPK(object obj, long id) + { + if (_autoPk != null) + { + _autoPk.SetValue(obj, Convert.ChangeType(id, _autoPk.ColumnType, null)); + } + } + + public Column[] InsertColumns + { + get + { + return _insertColumns; + } + } + + public Column[] InsertOrReplaceColumns + { + get + { + return _insertOrReplaceColumns; + } + } + + public Column FindColumnWithPropertyName(string propertyName) + { + var exact = Columns.FirstOrDefault(c => c.PropertyName == propertyName); + return exact; + } + + public Column FindColumn(string columnName) + { + var exact = Columns.FirstOrDefault(c => c.Name.ToLower() == columnName.ToLower()); + return exact; + } + + public class Column + { + PropertyInfo _prop; + + public string Name { get; private set; } + + public PropertyInfo PropertyInfo => _prop; + + public string PropertyName { get { return _prop.Name; } } + + public Type ColumnType { get; private set; } + + public string Collation { get; private set; } + + public bool IsAutoInc { get; private set; } + public bool IsAutoGuid { get; private set; } + + public bool IsPK { get; private set; } + + public IEnumerable Indices { get; set; } + + public bool IsNullable { get; private set; } + + public int? MaxStringLength { get; private set; } + + public bool StoreAsText { get; private set; } + + public Column(PropertyInfo prop, CreateFlags createFlags = CreateFlags.None) + { + var colAttr = prop.CustomAttributes.FirstOrDefault(x => x.AttributeType == typeof(ColumnAttribute)); + + _prop = prop; + Name = (colAttr != null && colAttr.ConstructorArguments.Count > 0) ? + colAttr.ConstructorArguments[0].Value?.ToString() : + prop.Name; + //If this type is Nullable then Nullable.GetUnderlyingType returns the T, otherwise it returns null, so get the actual type instead + ColumnType = Nullable.GetUnderlyingType(prop.PropertyType) ?? prop.PropertyType; + Collation = Orm.Collation(prop); + + IsPK = Orm.IsPK(prop) || + (((createFlags & CreateFlags.ImplicitPK) == CreateFlags.ImplicitPK) && + string.Compare(prop.Name, Orm.ImplicitPkName, StringComparison.OrdinalIgnoreCase) == 0); + + var isAuto = Orm.IsAutoInc(prop) || (IsPK && ((createFlags & CreateFlags.AutoIncPK) == CreateFlags.AutoIncPK)); + IsAutoGuid = isAuto && ColumnType == typeof(Guid); + IsAutoInc = isAuto && !IsAutoGuid; + + Indices = Orm.GetIndices(prop); + if (!Indices.Any() + && !IsPK + && ((createFlags & CreateFlags.ImplicitIndex) == CreateFlags.ImplicitIndex) + && Name.EndsWith(Orm.ImplicitIndexSuffix, StringComparison.OrdinalIgnoreCase) + ) + { + Indices = new IndexedAttribute[] { new IndexedAttribute() }; + } + IsNullable = !(IsPK || Orm.IsMarkedNotNull(prop)); + MaxStringLength = Orm.MaxStringLength(prop); + + StoreAsText = prop.PropertyType.GetTypeInfo().CustomAttributes.Any(x => x.AttributeType == typeof(StoreAsTextAttribute)); + } + + public void SetValue(object obj, object val) + { + if (val != null && ColumnType.GetTypeInfo().IsEnum) + { + _prop.SetValue(obj, Enum.ToObject(ColumnType, val)); + } + else + { + _prop.SetValue(obj, val, null); + } + } + + public object GetValue(object obj) + { + return _prop.GetValue(obj, null); + } + } + } + + class EnumCacheInfo + { + public EnumCacheInfo(Type type) + { + var typeInfo = type.GetTypeInfo(); + + IsEnum = typeInfo.IsEnum; + + if (IsEnum) + { + StoreAsText = typeInfo.CustomAttributes.Any(x => x.AttributeType == typeof(StoreAsTextAttribute)); + + if (StoreAsText) + { + EnumValues = new Dictionary(); + foreach (object e in Enum.GetValues(type)) + { + EnumValues[Convert.ToInt32(e)] = e.ToString(); + } + } + } + } + + public bool IsEnum { get; private set; } + + public bool StoreAsText { get; private set; } + + public Dictionary EnumValues { get; private set; } + } + + static class EnumCache + { + static readonly Dictionary Cache = new Dictionary(); + + public static EnumCacheInfo GetInfo() + { + return GetInfo(typeof(T)); + } + + public static EnumCacheInfo GetInfo(Type type) + { + lock (Cache) + { + EnumCacheInfo info = null; + if (!Cache.TryGetValue(type, out info)) + { + info = new EnumCacheInfo(type); + Cache[type] = info; + } + + return info; + } + } + } + + public static class Orm + { + public const int DefaultMaxStringLength = 140; + public const string ImplicitPkName = "Id"; + public const string ImplicitIndexSuffix = "Id"; + + public static Type GetType(object obj) + { + if (obj == null) + return typeof(object); + var rt = obj as IReflectableType; + if (rt != null) + return rt.GetTypeInfo().AsType(); + return obj.GetType(); + } + + public static string SqlDecl(TableMapping.Column p, bool storeDateTimeAsTicks) + { + string decl = "\"" + p.Name + "\" " + SqlType(p, storeDateTimeAsTicks) + " "; + + if (p.IsPK) + { + decl += "primary key "; + } + if (p.IsAutoInc) + { + decl += "autoincrement "; + } + if (!p.IsNullable) + { + decl += "not null "; + } + if (!string.IsNullOrEmpty(p.Collation)) + { + decl += "collate " + p.Collation + " "; + } + + return decl; + } + + public static string SqlType(TableMapping.Column p, bool storeDateTimeAsTicks) + { + var clrType = p.ColumnType; + if (clrType == typeof(Boolean) || clrType == typeof(Byte) || clrType == typeof(UInt16) || clrType == typeof(SByte) || clrType == typeof(Int16) || clrType == typeof(Int32) || clrType == typeof(UInt32) || clrType == typeof(Int64)) + { + return "integer"; + } + else if (clrType == typeof(Single) || clrType == typeof(Double) || clrType == typeof(Decimal)) + { + return "float"; + } + else if (clrType == typeof(String) || clrType == typeof(StringBuilder) || clrType == typeof(Uri) || clrType == typeof(UriBuilder)) + { + int? len = p.MaxStringLength; + + if (len.HasValue) + return "varchar(" + len.Value + ")"; + + return "varchar"; + } + else if (clrType == typeof(TimeSpan)) + { + return "bigint"; + } + else if (clrType == typeof(DateTime)) + { + return storeDateTimeAsTicks ? "bigint" : "datetime"; + } + else if (clrType == typeof(DateTimeOffset)) + { + return "bigint"; + } + else if (clrType.GetTypeInfo().IsEnum) + { + if (p.StoreAsText) + return "varchar"; + else + return "integer"; + } + else if (clrType == typeof(byte[])) + { + return "blob"; + } + else if (clrType == typeof(Guid)) + { + return "varchar(36)"; + } + else + { + throw new NotSupportedException("Don't know about " + clrType); + } + } + + public static bool IsPK(MemberInfo p) + { + return p.CustomAttributes.Any(x => x.AttributeType == typeof(PrimaryKeyAttribute)); + } + + public static string Collation(MemberInfo p) + { + return + (p.CustomAttributes + .Where(x => typeof(CollationAttribute) == x.AttributeType) + .Select(x => { + var args = x.ConstructorArguments; + return args.Count > 0 ? ((args[0].Value as string) ?? "") : ""; + }) + .FirstOrDefault()) ?? ""; + } + + public static bool IsAutoInc(MemberInfo p) + { + return p.CustomAttributes.Any(x => x.AttributeType == typeof(AutoIncrementAttribute)); + } + + public static FieldInfo GetField(TypeInfo t, string name) + { + var f = t.GetDeclaredField(name); + if (f != null) + return f; + return GetField(t.BaseType.GetTypeInfo(), name); + } + + public static PropertyInfo GetProperty(TypeInfo t, string name) + { + var f = t.GetDeclaredProperty(name); + if (f != null) + return f; + return GetProperty(t.BaseType.GetTypeInfo(), name); + } + + public static object InflateAttribute(CustomAttributeData x) + { + var atype = x.AttributeType; + var typeInfo = atype.GetTypeInfo(); + var args = x.ConstructorArguments.Select(a => a.Value).ToArray(); + var r = Activator.CreateInstance(x.AttributeType, args); + foreach (var arg in x.NamedArguments) + { + if (arg.IsField) + { + GetField(typeInfo, arg.MemberName).SetValue(r, arg.TypedValue.Value); + } + else + { + GetProperty(typeInfo, arg.MemberName).SetValue(r, arg.TypedValue.Value); + } + } + return r; + } + + public static IEnumerable GetIndices(MemberInfo p) + { + var indexedInfo = typeof(IndexedAttribute).GetTypeInfo(); + return + p.CustomAttributes + .Where(x => indexedInfo.IsAssignableFrom(x.AttributeType.GetTypeInfo())) + .Select(x => (IndexedAttribute)InflateAttribute(x)); + } + + public static int? MaxStringLength(PropertyInfo p) + { + var attr = p.CustomAttributes.FirstOrDefault(x => x.AttributeType == typeof(MaxLengthAttribute)); + if (attr != null) + { + var attrv = (MaxLengthAttribute)InflateAttribute(attr); + return attrv.Value; + } + return null; + } + + public static bool IsMarkedNotNull(MemberInfo p) + { + return p.CustomAttributes.Any(x => x.AttributeType == typeof(NotNullAttribute)); + } + } + + public partial class SQLiteCommand + { + SQLiteConnection _conn; + private List _bindings; + + public string CommandText { get; set; } + + public SQLiteCommand(SQLiteConnection conn) + { + _conn = conn; + _bindings = new List(); + CommandText = ""; + } + + public int ExecuteNonQuery() + { + if (_conn.Trace) + { + _conn.Tracer?.Invoke("Executing: " + this); + } + + var r = SQLite3.Result.OK; + var stmt = Prepare(); + r = SQLite3.Step(stmt); + Finalize(stmt); + if (r == SQLite3.Result.Done) + { + int rowsAffected = SQLite3.Changes(_conn.Handle); + return rowsAffected; + } + else if (r == SQLite3.Result.Error) + { + string msg = SQLite3.GetErrmsg(_conn.Handle); + throw SQLiteException.New(r, msg); + } + else if (r == SQLite3.Result.Constraint) + { + if (SQLite3.ExtendedErrCode(_conn.Handle) == SQLite3.ExtendedResult.ConstraintNotNull) + { + throw NotNullConstraintViolationException.New(r, SQLite3.GetErrmsg(_conn.Handle)); + } + } + + throw SQLiteException.New(r, r.ToString()); + } + + public IEnumerable ExecuteDeferredQuery() + { + return ExecuteDeferredQuery(_conn.GetMapping(typeof(T))); + } + + public List ExecuteQuery() + { + return ExecuteDeferredQuery(_conn.GetMapping(typeof(T))).ToList(); + } + + public List ExecuteQuery(TableMapping map) + { + return ExecuteDeferredQuery(map).ToList(); + } + + /// + /// Invoked every time an instance is loaded from the database. + /// + /// + /// The newly created object. + /// + /// + /// This can be overridden in combination with the + /// method to hook into the life-cycle of objects. + /// + protected virtual void OnInstanceCreated(object obj) + { + // Can be overridden. + } + + public IEnumerable ExecuteDeferredQuery(TableMapping map) + { + if (_conn.Trace) + { + _conn.Tracer?.Invoke("Executing Query: " + this); + } + + var stmt = Prepare(); + try + { + var cols = new TableMapping.Column[SQLite3.ColumnCount(stmt)]; + + for (int i = 0; i < cols.Length; i++) + { + var name = SQLite3.ColumnName16(stmt, i); + cols[i] = map.FindColumn(name); + } + + while (SQLite3.Step(stmt) == SQLite3.Result.Row) + { + var obj = Activator.CreateInstance(map.MappedType); + for (int i = 0; i < cols.Length; i++) + { + if (cols[i] == null) + continue; + var colType = SQLite3.ColumnType(stmt, i); + var val = ReadCol(stmt, i, colType, cols[i].ColumnType); + cols[i].SetValue(obj, val); + } + OnInstanceCreated(obj); + yield return (T)obj; + } + } + finally + { + SQLite3.Finalize(stmt); + } + } + + public T ExecuteScalar() + { + if (_conn.Trace) + { + _conn.Tracer?.Invoke("Executing Query: " + this); + } + + T val = default(T); + + var stmt = Prepare(); + + try + { + var r = SQLite3.Step(stmt); + if (r == SQLite3.Result.Row) + { + var colType = SQLite3.ColumnType(stmt, 0); + val = (T)ReadCol(stmt, 0, colType, typeof(T)); + } + else if (r == SQLite3.Result.Done) + { + } + else + { + throw SQLiteException.New(r, SQLite3.GetErrmsg(_conn.Handle)); + } + } + finally + { + Finalize(stmt); + } + + return val; + } + + public void Bind(string name, object val) + { + _bindings.Add(new Binding + { + Name = name, + Value = val + }); + } + + public void Bind(object val) + { + Bind(null, val); + } + + public override string ToString() + { + var parts = new string[1 + _bindings.Count]; + parts[0] = CommandText; + var i = 1; + foreach (var b in _bindings) + { + parts[i] = string.Format(" {0}: {1}", i - 1, b.Value); + i++; + } + return string.Join(Environment.NewLine, parts); + } + + Sqlite3Statement Prepare() + { + var stmt = SQLite3.Prepare2(_conn.Handle, CommandText); + BindAll(stmt); + return stmt; + } + + void Finalize(Sqlite3Statement stmt) + { + SQLite3.Finalize(stmt); + } + + void BindAll(Sqlite3Statement stmt) + { + int nextIdx = 1; + foreach (var b in _bindings) + { + if (b.Name != null) + { + b.Index = SQLite3.BindParameterIndex(stmt, b.Name); + } + else + { + b.Index = nextIdx++; + } + + BindParameter(stmt, b.Index, b.Value, _conn.StoreDateTimeAsTicks); + } + } + + static IntPtr NegativePointer = new IntPtr(-1); + + const string DateTimeExactStoreFormat = "yyyy'-'MM'-'dd'T'HH':'mm':'ss'.'fff"; + + internal static void BindParameter(Sqlite3Statement stmt, int index, object value, bool storeDateTimeAsTicks) + { + if (value == null) + { + SQLite3.BindNull(stmt, index); + } + else + { + if (value is Int32) + { + SQLite3.BindInt(stmt, index, (int)value); + } + else if (value is String) + { + SQLite3.BindText(stmt, index, (string)value, -1, NegativePointer); + } + else if (value is Byte || value is UInt16 || value is SByte || value is Int16) + { + SQLite3.BindInt(stmt, index, Convert.ToInt32(value)); + } + else if (value is Boolean) + { + SQLite3.BindInt(stmt, index, (bool)value ? 1 : 0); + } + else if (value is UInt32 || value is Int64) + { + SQLite3.BindInt64(stmt, index, Convert.ToInt64(value)); + } + else if (value is Single || value is Double || value is Decimal) + { + SQLite3.BindDouble(stmt, index, Convert.ToDouble(value)); + } + else if (value is TimeSpan) + { + SQLite3.BindInt64(stmt, index, ((TimeSpan)value).Ticks); + } + else if (value is DateTime) + { + if (storeDateTimeAsTicks) + { + SQLite3.BindInt64(stmt, index, ((DateTime)value).Ticks); + } + else + { + SQLite3.BindText(stmt, index, ((DateTime)value).ToString(DateTimeExactStoreFormat, System.Globalization.CultureInfo.InvariantCulture), -1, NegativePointer); + } + } + else if (value is DateTimeOffset) + { + SQLite3.BindInt64(stmt, index, ((DateTimeOffset)value).UtcTicks); + } + else if (value is byte[]) + { + SQLite3.BindBlob(stmt, index, (byte[])value, ((byte[])value).Length, NegativePointer); + } + else if (value is Guid) + { + SQLite3.BindText(stmt, index, ((Guid)value).ToString(), 72, NegativePointer); + } + else if (value is Uri) + { + SQLite3.BindText(stmt, index, ((Uri)value).ToString(), -1, NegativePointer); + } + else if (value is StringBuilder) + { + SQLite3.BindText(stmt, index, ((StringBuilder)value).ToString(), -1, NegativePointer); + } + else if (value is UriBuilder) + { + SQLite3.BindText(stmt, index, ((UriBuilder)value).ToString(), -1, NegativePointer); + } + else + { + // Now we could possibly get an enum, retrieve cached info + var valueType = value.GetType(); + var enumInfo = EnumCache.GetInfo(valueType); + if (enumInfo.IsEnum) + { + var enumIntValue = Convert.ToInt32(value); + if (enumInfo.StoreAsText) + SQLite3.BindText(stmt, index, enumInfo.EnumValues[enumIntValue], -1, NegativePointer); + else + SQLite3.BindInt(stmt, index, enumIntValue); + } + else + { + throw new NotSupportedException("Cannot store type: " + Orm.GetType(value)); + } + } + } + } + + class Binding + { + public string Name { get; set; } + + public object Value { get; set; } + + public int Index { get; set; } + } + + object ReadCol(Sqlite3Statement stmt, int index, SQLite3.ColType type, Type clrType) + { + if (type == SQLite3.ColType.Null) + { + return null; + } + else + { + var clrTypeInfo = clrType.GetTypeInfo(); + if (clrType == typeof(String)) + { + return SQLite3.ColumnString(stmt, index); + } + else if (clrType == typeof(Int32)) + { + return (int)SQLite3.ColumnInt(stmt, index); + } + else if (clrType == typeof(Boolean)) + { + return SQLite3.ColumnInt(stmt, index) == 1; + } + else if (clrType == typeof(double)) + { + return SQLite3.ColumnDouble(stmt, index); + } + else if (clrType == typeof(float)) + { + return (float)SQLite3.ColumnDouble(stmt, index); + } + else if (clrType == typeof(TimeSpan)) + { + return new TimeSpan(SQLite3.ColumnInt64(stmt, index)); + } + else if (clrType == typeof(DateTime)) + { + if (_conn.StoreDateTimeAsTicks) + { + return new DateTime(SQLite3.ColumnInt64(stmt, index)); + } + else + { + var text = SQLite3.ColumnString(stmt, index); + DateTime resultDate; + if (!DateTime.TryParseExact(text, DateTimeExactStoreFormat, System.Globalization.CultureInfo.InvariantCulture, System.Globalization.DateTimeStyles.None, out resultDate)) + { + resultDate = DateTime.Parse(text); + } + return resultDate; + } + } + else if (clrType == typeof(DateTimeOffset)) + { + return new DateTimeOffset(SQLite3.ColumnInt64(stmt, index), TimeSpan.Zero); + } + else if (clrTypeInfo.IsEnum) + { + if (type == SQLite3.ColType.Text) + { + var value = SQLite3.ColumnString(stmt, index); + return Enum.Parse(clrType, value.ToString(), true); + } + else + return SQLite3.ColumnInt(stmt, index); + } + else if (clrType == typeof(Int64)) + { + return SQLite3.ColumnInt64(stmt, index); + } + else if (clrType == typeof(UInt32)) + { + return (uint)SQLite3.ColumnInt64(stmt, index); + } + else if (clrType == typeof(decimal)) + { + return (decimal)SQLite3.ColumnDouble(stmt, index); + } + else if (clrType == typeof(Byte)) + { + return (byte)SQLite3.ColumnInt(stmt, index); + } + else if (clrType == typeof(UInt16)) + { + return (ushort)SQLite3.ColumnInt(stmt, index); + } + else if (clrType == typeof(Int16)) + { + return (short)SQLite3.ColumnInt(stmt, index); + } + else if (clrType == typeof(sbyte)) + { + return (sbyte)SQLite3.ColumnInt(stmt, index); + } + else if (clrType == typeof(byte[])) + { + return SQLite3.ColumnByteArray(stmt, index); + } + else if (clrType == typeof(Guid)) + { + var text = SQLite3.ColumnString(stmt, index); + return new Guid(text); + } + else if (clrType == typeof(Uri)) + { + var text = SQLite3.ColumnString(stmt, index); + return new Uri(text); + } + else if (clrType == typeof(StringBuilder)) + { + var text = SQLite3.ColumnString(stmt, index); + return new StringBuilder(text); + } + else if (clrType == typeof(UriBuilder)) + { + var text = SQLite3.ColumnString(stmt, index); + return new UriBuilder(text); + } + else + { + throw new NotSupportedException("Don't know how to read " + clrType); + } + } + } + } + + /// + /// Since the insert never changed, we only need to prepare once. + /// + class PreparedSqlLiteInsertCommand : IDisposable + { + bool Initialized; + + SQLiteConnection Connection; + + string CommandText; + + Sqlite3Statement Statement; + static readonly Sqlite3Statement NullStatement = default(Sqlite3Statement); + + public PreparedSqlLiteInsertCommand(SQLiteConnection conn, string commandText) + { + Connection = conn; + CommandText = commandText; + } + + public int ExecuteNonQuery(object[] source) + { + if (Initialized && Statement == NullStatement) + { + throw new ObjectDisposedException(nameof(PreparedSqlLiteInsertCommand)); + } + + if (Connection.Trace) + { + Connection.Tracer?.Invoke("Executing: " + CommandText); + } + + var r = SQLite3.Result.OK; + + if (!Initialized) + { + Statement = SQLite3.Prepare2(Connection.Handle, CommandText); + Initialized = true; + } + + //bind the values. + if (source != null) + { + for (int i = 0; i < source.Length; i++) + { + SQLiteCommand.BindParameter(Statement, i + 1, source[i], Connection.StoreDateTimeAsTicks); + } + } + r = SQLite3.Step(Statement); + + if (r == SQLite3.Result.Done) + { + int rowsAffected = SQLite3.Changes(Connection.Handle); + SQLite3.Reset(Statement); + return rowsAffected; + } + else if (r == SQLite3.Result.Error) + { + string msg = SQLite3.GetErrmsg(Connection.Handle); + SQLite3.Reset(Statement); + throw SQLiteException.New(r, msg); + } + else if (r == SQLite3.Result.Constraint && SQLite3.ExtendedErrCode(Connection.Handle) == SQLite3.ExtendedResult.ConstraintNotNull) + { + SQLite3.Reset(Statement); + throw NotNullConstraintViolationException.New(r, SQLite3.GetErrmsg(Connection.Handle)); + } + else + { + SQLite3.Reset(Statement); + throw SQLiteException.New(r, r.ToString()); + } + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + void Dispose(bool disposing) + { + var s = Statement; + Statement = NullStatement; + Connection = null; + if (s != NullStatement) + { + SQLite3.Finalize(s); + } + } + + ~PreparedSqlLiteInsertCommand() + { + Dispose(false); + } + } + + public enum CreateTableResult + { + Created, + Migrated, + } + + public class CreateTablesResult + { + public Dictionary Results { get; private set; } + + public CreateTablesResult() + { + Results = new Dictionary(); + } + } + + public abstract class BaseTableQuery + { + protected class Ordering + { + public string ColumnName { get; set; } + public bool Ascending { get; set; } + } + } + + public class TableQuery : BaseTableQuery, IEnumerable + { + public SQLiteConnection Connection { get; private set; } + + public TableMapping Table { get; private set; } + + Expression _where; + List _orderBys; + int? _limit; + int? _offset; + + BaseTableQuery _joinInner; + Expression _joinInnerKeySelector; + BaseTableQuery _joinOuter; + Expression _joinOuterKeySelector; + Expression _joinSelector; + + Expression _selector; + + TableQuery(SQLiteConnection conn, TableMapping table) + { + Connection = conn; + Table = table; + } + + public TableQuery(SQLiteConnection conn) + { + Connection = conn; + Table = Connection.GetMapping(typeof(T)); + } + + public TableQuery Clone() + { + var q = new TableQuery(Connection, Table); + q._where = _where; + q._deferred = _deferred; + if (_orderBys != null) + { + q._orderBys = new List(_orderBys); + } + q._limit = _limit; + q._offset = _offset; + q._joinInner = _joinInner; + q._joinInnerKeySelector = _joinInnerKeySelector; + q._joinOuter = _joinOuter; + q._joinOuterKeySelector = _joinOuterKeySelector; + q._joinSelector = _joinSelector; + q._selector = _selector; + return q; + } + + /// + /// Filters the query based on a predicate. + /// + public TableQuery Where(Expression> predExpr) + { + if (predExpr.NodeType == ExpressionType.Lambda) + { + var lambda = (LambdaExpression)predExpr; + var pred = lambda.Body; + var q = Clone(); + q.AddWhere(pred); + return q; + } + else + { + throw new NotSupportedException("Must be a predicate"); + } + } + + /// + /// Delete all the rows that match this query. + /// + public int Delete() + { + return Delete(null); + } + + /// + /// Delete all the rows that match this query and the given predicate. + /// + public int Delete(Expression> predExpr) + { + if (_limit.HasValue || _offset.HasValue) + throw new InvalidOperationException("Cannot delete with limits or offsets"); + + if (_where == null && predExpr == null) + throw new InvalidOperationException("No condition specified"); + + var pred = _where; + + if (predExpr != null && predExpr.NodeType == ExpressionType.Lambda) + { + var lambda = (LambdaExpression)predExpr; + pred = pred != null ? Expression.AndAlso(pred, lambda.Body) : lambda.Body; + } + + var args = new List(); + var cmdText = "delete from \"" + Table.TableName + "\""; + var w = CompileExpr(pred, args); + cmdText += " where " + w.CommandText; + + var command = Connection.CreateCommand(cmdText, args.ToArray()); + + int result = command.ExecuteNonQuery(); + return result; + } + + /// + /// Yields a given number of elements from the query and then skips the remainder. + /// + public TableQuery Take(int n) + { + var q = Clone(); + q._limit = n; + return q; + } + + /// + /// Skips a given number of elements from the query and then yields the remainder. + /// + public TableQuery Skip(int n) + { + var q = Clone(); + q._offset = n; + return q; + } + + /// + /// Returns the element at a given index + /// + public T ElementAt(int index) + { + return Skip(index).Take(1).First(); + } + + bool _deferred; + public TableQuery Deferred() + { + var q = Clone(); + q._deferred = true; + return q; + } + + /// + /// Order the query results according to a key. + /// + public TableQuery OrderBy(Expression> orderExpr) + { + return AddOrderBy(orderExpr, true); + } + + /// + /// Order the query results according to a key. + /// + public TableQuery OrderByDescending(Expression> orderExpr) + { + return AddOrderBy(orderExpr, false); + } + + /// + /// Order the query results according to a key. + /// + public TableQuery ThenBy(Expression> orderExpr) + { + return AddOrderBy(orderExpr, true); + } + + /// + /// Order the query results according to a key. + /// + public TableQuery ThenByDescending(Expression> orderExpr) + { + return AddOrderBy(orderExpr, false); + } + + TableQuery AddOrderBy(Expression> orderExpr, bool asc) + { + if (orderExpr.NodeType == ExpressionType.Lambda) + { + var lambda = (LambdaExpression)orderExpr; + + MemberExpression mem = null; + + var unary = lambda.Body as UnaryExpression; + if (unary != null && unary.NodeType == ExpressionType.Convert) + { + mem = unary.Operand as MemberExpression; + } + else + { + mem = lambda.Body as MemberExpression; + } + + if (mem != null && (mem.Expression.NodeType == ExpressionType.Parameter)) + { + var q = Clone(); + if (q._orderBys == null) + { + q._orderBys = new List(); + } + q._orderBys.Add(new Ordering + { + ColumnName = Table.FindColumnWithPropertyName(mem.Member.Name).Name, + Ascending = asc + }); + return q; + } + else + { + throw new NotSupportedException("Order By does not support: " + orderExpr); + } + } + else + { + throw new NotSupportedException("Must be a predicate"); + } + } + + private void AddWhere(Expression pred) + { + if (_where == null) + { + _where = pred; + } + else + { + _where = Expression.AndAlso(_where, pred); + } + } + + ///// + ///// Performs an inner join of two queries based on matching keys extracted from the elements. + ///// + //public TableQuery Join ( + // TableQuery inner, + // Expression> outerKeySelector, + // Expression> innerKeySelector, + // Expression> resultSelector) + //{ + // var q = new TableQuery (Connection, Connection.GetMapping (typeof (TResult))) { + // _joinOuter = this, + // _joinOuterKeySelector = outerKeySelector, + // _joinInner = inner, + // _joinInnerKeySelector = innerKeySelector, + // _joinSelector = resultSelector, + // }; + // return q; + //} + + // Not needed until Joins are supported + // Keeping this commented out forces the default Linq to objects processor to run + //public TableQuery Select (Expression> selector) + //{ + // var q = Clone (); + // q._selector = selector; + // return q; + //} + + private SQLiteCommand GenerateCommand(string selectionList) + { + if (_joinInner != null && _joinOuter != null) + { + throw new NotSupportedException("Joins are not supported."); + } + else + { + var cmdText = "select " + selectionList + " from \"" + Table.TableName + "\""; + var args = new List(); + if (_where != null) + { + var w = CompileExpr(_where, args); + cmdText += " where " + w.CommandText; + } + if ((_orderBys != null) && (_orderBys.Count > 0)) + { + var t = string.Join(", ", _orderBys.Select(o => "\"" + o.ColumnName + "\"" + (o.Ascending ? "" : " desc")).ToArray()); + cmdText += " order by " + t; + } + if (_limit.HasValue) + { + cmdText += " limit " + _limit.Value; + } + if (_offset.HasValue) + { + if (!_limit.HasValue) + { + cmdText += " limit -1 "; + } + cmdText += " offset " + _offset.Value; + } + return Connection.CreateCommand(cmdText, args.ToArray()); + } + } + + class CompileResult + { + public string CommandText { get; set; } + + public object Value { get; set; } + } + + private CompileResult CompileExpr(Expression expr, List queryArgs) + { + if (expr == null) + { + throw new NotSupportedException("Expression is NULL"); + } + else if (expr is BinaryExpression) + { + var bin = (BinaryExpression)expr; + + // VB turns 'x=="foo"' into 'CompareString(x,"foo",true/false)==0', so we need to unwrap it + // http://blogs.msdn.com/b/vbteam/archive/2007/09/18/vb-expression-trees-string-comparisons.aspx + if (bin.Left.NodeType == ExpressionType.Call) + { + var call = (MethodCallExpression)bin.Left; + if (call.Method.DeclaringType.FullName == "Microsoft.VisualBasic.CompilerServices.Operators" + && call.Method.Name == "CompareString") + bin = Expression.MakeBinary(bin.NodeType, call.Arguments[0], call.Arguments[1]); + } + + + var leftr = CompileExpr(bin.Left, queryArgs); + var rightr = CompileExpr(bin.Right, queryArgs); + + //If either side is a parameter and is null, then handle the other side specially (for "is null"/"is not null") + string text; + if (leftr.CommandText == "?" && leftr.Value == null) + text = CompileNullBinaryExpression(bin, rightr); + else if (rightr.CommandText == "?" && rightr.Value == null) + text = CompileNullBinaryExpression(bin, leftr); + else + text = "(" + leftr.CommandText + " " + GetSqlName(bin) + " " + rightr.CommandText + ")"; + return new CompileResult { CommandText = text }; + } + else if (expr.NodeType == ExpressionType.Not) + { + var operandExpr = ((UnaryExpression)expr).Operand; + var opr = CompileExpr(operandExpr, queryArgs); + object val = opr.Value; + if (val is bool) + val = !((bool)val); + return new CompileResult + { + CommandText = "NOT(" + opr.CommandText + ")", + Value = val + }; + } + else if (expr.NodeType == ExpressionType.Call) + { + + var call = (MethodCallExpression)expr; + var args = new CompileResult[call.Arguments.Count]; + var obj = call.Object != null ? CompileExpr(call.Object, queryArgs) : null; + + for (var i = 0; i < args.Length; i++) + { + args[i] = CompileExpr(call.Arguments[i], queryArgs); + } + + var sqlCall = ""; + + if (call.Method.Name == "Like" && args.Length == 2) + { + sqlCall = "(" + args[0].CommandText + " like " + args[1].CommandText + ")"; + } + else if (call.Method.Name == "Contains" && args.Length == 2) + { + sqlCall = "(" + args[1].CommandText + " in " + args[0].CommandText + ")"; + } + else if (call.Method.Name == "Contains" && args.Length == 1) + { + if (call.Object != null && call.Object.Type == typeof(string)) + { + sqlCall = "( instr(" + obj.CommandText + "," + args[0].CommandText + ") >0 )"; + } + else + { + sqlCall = "(" + args[0].CommandText + " in " + obj.CommandText + ")"; + } + } + else if (call.Method.Name == "StartsWith" && args.Length >= 1) + { + var startsWithCmpOp = StringComparison.CurrentCulture; + if (args.Length == 2) + { + startsWithCmpOp = (StringComparison)args[1].Value; + } + switch (startsWithCmpOp) + { + case StringComparison.Ordinal: + case StringComparison.CurrentCulture: + sqlCall = "( substr(" + obj.CommandText + ", 1, " + args[0].Value.ToString().Length + ") = " + args[0].CommandText + ")"; + break; + case StringComparison.OrdinalIgnoreCase: + case StringComparison.CurrentCultureIgnoreCase: + sqlCall = "(" + obj.CommandText + " like (" + args[0].CommandText + " || '%'))"; + break; + } + + } + else if (call.Method.Name == "EndsWith" && args.Length >= 1) + { + var endsWithCmpOp = StringComparison.CurrentCulture; + if (args.Length == 2) + { + endsWithCmpOp = (StringComparison)args[1].Value; + } + switch (endsWithCmpOp) + { + case StringComparison.Ordinal: + case StringComparison.CurrentCulture: + sqlCall = "( substr(" + obj.CommandText + ", length(" + obj.CommandText + ") - " + args[0].Value.ToString().Length + "+1, " + args[0].Value.ToString().Length + ") = " + args[0].CommandText + ")"; + break; + case StringComparison.OrdinalIgnoreCase: + case StringComparison.CurrentCultureIgnoreCase: + sqlCall = "(" + obj.CommandText + " like ('%' || " + args[0].CommandText + "))"; + break; + } + } + else if (call.Method.Name == "Equals" && args.Length == 1) + { + sqlCall = "(" + obj.CommandText + " = (" + args[0].CommandText + "))"; + } + else if (call.Method.Name == "ToLower") + { + sqlCall = "(lower(" + obj.CommandText + "))"; + } + else if (call.Method.Name == "ToUpper") + { + sqlCall = "(upper(" + obj.CommandText + "))"; + } + else if (call.Method.Name == "Replace" && args.Length == 2) + { + sqlCall = "(replace(" + obj.CommandText + "," + args[0].CommandText + "," + args[1].CommandText + "))"; + } + else + { + sqlCall = call.Method.Name.ToLower() + "(" + string.Join(",", args.Select(a => a.CommandText).ToArray()) + ")"; + } + return new CompileResult { CommandText = sqlCall }; + + } + else if (expr.NodeType == ExpressionType.Constant) + { + var c = (ConstantExpression)expr; + queryArgs.Add(c.Value); + return new CompileResult + { + CommandText = "?", + Value = c.Value + }; + } + else if (expr.NodeType == ExpressionType.Convert) + { + var u = (UnaryExpression)expr; + var ty = u.Type; + var valr = CompileExpr(u.Operand, queryArgs); + return new CompileResult + { + CommandText = valr.CommandText, + Value = valr.Value != null ? ConvertTo(valr.Value, ty) : null + }; + } + else if (expr.NodeType == ExpressionType.MemberAccess) + { + var mem = (MemberExpression)expr; + + var paramExpr = mem.Expression as ParameterExpression; + if (paramExpr == null) + { + var convert = mem.Expression as UnaryExpression; + if (convert != null && convert.NodeType == ExpressionType.Convert) + { + paramExpr = convert.Operand as ParameterExpression; + } + } + + if (paramExpr != null) + { + // + // This is a column of our table, output just the column name + // Need to translate it if that column name is mapped + // + var columnName = Table.FindColumnWithPropertyName(mem.Member.Name).Name; + return new CompileResult { CommandText = "\"" + columnName + "\"" }; + } + else + { + object obj = null; + if (mem.Expression != null) + { + var r = CompileExpr(mem.Expression, queryArgs); + if (r.Value == null) + { + throw new NotSupportedException("Member access failed to compile expression"); + } + if (r.CommandText == "?") + { + queryArgs.RemoveAt(queryArgs.Count - 1); + } + obj = r.Value; + } + + // + // Get the member value + // + object val = null; + + if (mem.Member is PropertyInfo) + { + var m = (PropertyInfo)mem.Member; + val = m.GetValue(obj, null); + } + else if (mem.Member is FieldInfo) + { + var m = (FieldInfo)mem.Member; + val = m.GetValue(obj); + } + else + { + throw new NotSupportedException("MemberExpr: " + mem.Member.GetType()); + } + + // + // Work special magic for enumerables + // + if (val != null && val is System.Collections.IEnumerable && !(val is string) && !(val is System.Collections.Generic.IEnumerable)) + { + var sb = new System.Text.StringBuilder(); + sb.Append("("); + var head = ""; + foreach (var a in (System.Collections.IEnumerable)val) + { + queryArgs.Add(a); + sb.Append(head); + sb.Append("?"); + head = ","; + } + sb.Append(")"); + return new CompileResult + { + CommandText = sb.ToString(), + Value = val + }; + } + else + { + queryArgs.Add(val); + return new CompileResult + { + CommandText = "?", + Value = val + }; + } + } + } + throw new NotSupportedException("Cannot compile: " + expr.NodeType.ToString()); + } + + static object ConvertTo(object obj, Type t) + { + Type nut = Nullable.GetUnderlyingType(t); + + if (nut != null) + { + if (obj == null) return null; + return Convert.ChangeType(obj, nut); + } + else + { + return Convert.ChangeType(obj, t); + } + } + + /// + /// Compiles a BinaryExpression where one of the parameters is null. + /// + /// The expression to compile + /// The non-null parameter + private string CompileNullBinaryExpression(BinaryExpression expression, CompileResult parameter) + { + if (expression.NodeType == ExpressionType.Equal) + return "(" + parameter.CommandText + " is ?)"; + else if (expression.NodeType == ExpressionType.NotEqual) + return "(" + parameter.CommandText + " is not ?)"; + else if (expression.NodeType == ExpressionType.GreaterThan + || expression.NodeType == ExpressionType.GreaterThanOrEqual + || expression.NodeType == ExpressionType.LessThan + || expression.NodeType == ExpressionType.LessThanOrEqual) + return "(" + parameter.CommandText + " < ?)"; // always false + else + throw new NotSupportedException("Cannot compile Null-BinaryExpression with type " + expression.NodeType.ToString()); + } + + string GetSqlName(Expression expr) + { + var n = expr.NodeType; + if (n == ExpressionType.GreaterThan) + return ">"; + else if (n == ExpressionType.GreaterThanOrEqual) + { + return ">="; + } + else if (n == ExpressionType.LessThan) + { + return "<"; + } + else if (n == ExpressionType.LessThanOrEqual) + { + return "<="; + } + else if (n == ExpressionType.And) + { + return "&"; + } + else if (n == ExpressionType.AndAlso) + { + return "and"; + } + else if (n == ExpressionType.Or) + { + return "|"; + } + else if (n == ExpressionType.OrElse) + { + return "or"; + } + else if (n == ExpressionType.Equal) + { + return "="; + } + else if (n == ExpressionType.NotEqual) + { + return "!="; + } + else + { + throw new NotSupportedException("Cannot get SQL for: " + n); + } + } + + /// + /// Execute SELECT COUNT(*) on the query + /// + public int Count() + { + return GenerateCommand("count(*)").ExecuteScalar(); + } + + /// + /// Execute SELECT COUNT(*) on the query with an additional WHERE clause. + /// + public int Count(Expression> predExpr) + { + return Where(predExpr).Count(); + } + + public IEnumerator GetEnumerator() + { + if (!_deferred) + return GenerateCommand("*").ExecuteQuery().GetEnumerator(); + + return GenerateCommand("*").ExecuteDeferredQuery().GetEnumerator(); + } + + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + /// + /// Queries the database and returns the results as a List. + /// + public List ToList() + { + return GenerateCommand("*").ExecuteQuery(); + } + + /// + /// Queries the database and returns the results as an array. + /// + public T[] ToArray() + { + return GenerateCommand("*").ExecuteQuery().ToArray(); + } + + /// + /// Returns the first element of this query. + /// + public T First() + { + var query = Take(1); + return query.ToList().First(); + } + + /// + /// Returns the first element of this query, or null if no element is found. + /// + public T FirstOrDefault() + { + var query = Take(1); + return query.ToList().FirstOrDefault(); + } + + /// + /// Returns the first element of this query that matches the predicate. + /// + public T First(Expression> predExpr) + { + return Where(predExpr).First(); + } + + /// + /// Returns the first element of this query that matches the predicate, or null + /// if no element is found. + /// + public T FirstOrDefault(Expression> predExpr) + { + return Where(predExpr).FirstOrDefault(); + } + } + + public static class SQLite3 + { + public enum Result : int + { + OK = 0, + Error = 1, + Internal = 2, + Perm = 3, + Abort = 4, + Busy = 5, + Locked = 6, + NoMem = 7, + ReadOnly = 8, + Interrupt = 9, + IOError = 10, + Corrupt = 11, + NotFound = 12, + Full = 13, + CannotOpen = 14, + LockErr = 15, + Empty = 16, + SchemaChngd = 17, + TooBig = 18, + Constraint = 19, + Mismatch = 20, + Misuse = 21, + NotImplementedLFS = 22, + AccessDenied = 23, + Format = 24, + Range = 25, + NonDBFile = 26, + Notice = 27, + Warning = 28, + Row = 100, + Done = 101 + } + + public enum ExtendedResult : int + { + IOErrorRead = (Result.IOError | (1 << 8)), + IOErrorShortRead = (Result.IOError | (2 << 8)), + IOErrorWrite = (Result.IOError | (3 << 8)), + IOErrorFsync = (Result.IOError | (4 << 8)), + IOErrorDirFSync = (Result.IOError | (5 << 8)), + IOErrorTruncate = (Result.IOError | (6 << 8)), + IOErrorFStat = (Result.IOError | (7 << 8)), + IOErrorUnlock = (Result.IOError | (8 << 8)), + IOErrorRdlock = (Result.IOError | (9 << 8)), + IOErrorDelete = (Result.IOError | (10 << 8)), + IOErrorBlocked = (Result.IOError | (11 << 8)), + IOErrorNoMem = (Result.IOError | (12 << 8)), + IOErrorAccess = (Result.IOError | (13 << 8)), + IOErrorCheckReservedLock = (Result.IOError | (14 << 8)), + IOErrorLock = (Result.IOError | (15 << 8)), + IOErrorClose = (Result.IOError | (16 << 8)), + IOErrorDirClose = (Result.IOError | (17 << 8)), + IOErrorSHMOpen = (Result.IOError | (18 << 8)), + IOErrorSHMSize = (Result.IOError | (19 << 8)), + IOErrorSHMLock = (Result.IOError | (20 << 8)), + IOErrorSHMMap = (Result.IOError | (21 << 8)), + IOErrorSeek = (Result.IOError | (22 << 8)), + IOErrorDeleteNoEnt = (Result.IOError | (23 << 8)), + IOErrorMMap = (Result.IOError | (24 << 8)), + LockedSharedcache = (Result.Locked | (1 << 8)), + BusyRecovery = (Result.Busy | (1 << 8)), + CannottOpenNoTempDir = (Result.CannotOpen | (1 << 8)), + CannotOpenIsDir = (Result.CannotOpen | (2 << 8)), + CannotOpenFullPath = (Result.CannotOpen | (3 << 8)), + CorruptVTab = (Result.Corrupt | (1 << 8)), + ReadonlyRecovery = (Result.ReadOnly | (1 << 8)), + ReadonlyCannotLock = (Result.ReadOnly | (2 << 8)), + ReadonlyRollback = (Result.ReadOnly | (3 << 8)), + AbortRollback = (Result.Abort | (2 << 8)), + ConstraintCheck = (Result.Constraint | (1 << 8)), + ConstraintCommitHook = (Result.Constraint | (2 << 8)), + ConstraintForeignKey = (Result.Constraint | (3 << 8)), + ConstraintFunction = (Result.Constraint | (4 << 8)), + ConstraintNotNull = (Result.Constraint | (5 << 8)), + ConstraintPrimaryKey = (Result.Constraint | (6 << 8)), + ConstraintTrigger = (Result.Constraint | (7 << 8)), + ConstraintUnique = (Result.Constraint | (8 << 8)), + ConstraintVTab = (Result.Constraint | (9 << 8)), + NoticeRecoverWAL = (Result.Notice | (1 << 8)), + NoticeRecoverRollback = (Result.Notice | (2 << 8)) + } + + + public enum ConfigOption : int + { + SingleThread = 1, + MultiThread = 2, + Serialized = 3 + } + + const string LibraryPath = "e_sqlite3"; + +#if !USE_CSHARP_SQLITE && !USE_WP8_NATIVE_SQLITE && !USE_SQLITEPCL_RAW + [DllImport(LibraryPath, EntryPoint = "sqlite3_threadsafe", CallingConvention = CallingConvention.Cdecl)] + public static extern int Threadsafe(); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_open", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Open([MarshalAs(UnmanagedType.LPStr)] string filename, out IntPtr db); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_open_v2", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Open([MarshalAs(UnmanagedType.LPStr)] string filename, out IntPtr db, int flags, IntPtr zvfs); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_open_v2", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Open(byte[] filename, out IntPtr db, int flags, IntPtr zvfs); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_open16", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Open16([MarshalAs(UnmanagedType.LPWStr)] string filename, out IntPtr db); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_enable_load_extension", CallingConvention = CallingConvention.Cdecl)] + public static extern Result EnableLoadExtension(IntPtr db, int onoff); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_close", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Close(IntPtr db); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_close_v2", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Close2(IntPtr db); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_initialize", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Initialize(); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_shutdown", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Shutdown(); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_config", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Config(ConfigOption option); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_win32_set_directory", CallingConvention = CallingConvention.Cdecl, CharSet = CharSet.Unicode)] + public static extern int SetDirectory(uint directoryType, string directoryPath); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_busy_timeout", CallingConvention = CallingConvention.Cdecl)] + public static extern Result BusyTimeout(IntPtr db, int milliseconds); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_changes", CallingConvention = CallingConvention.Cdecl)] + public static extern int Changes(IntPtr db); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_prepare_v2", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Prepare2(IntPtr db, [MarshalAs(UnmanagedType.LPStr)] string sql, int numBytes, out IntPtr stmt, IntPtr pzTail); + +#if NETFX_CORE + [DllImport (LibraryPath, EntryPoint = "sqlite3_prepare_v2", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Prepare2 (IntPtr db, byte[] queryBytes, int numBytes, out IntPtr stmt, IntPtr pzTail); +#endif + + public static IntPtr Prepare2(IntPtr db, string query) + { + IntPtr stmt; +#if NETFX_CORE + byte[] queryBytes = System.Text.UTF8Encoding.UTF8.GetBytes (query); + var r = Prepare2 (db, queryBytes, queryBytes.Length, out stmt, IntPtr.Zero); +#else + var r = Prepare2(db, query, System.Text.UTF8Encoding.UTF8.GetByteCount(query), out stmt, IntPtr.Zero); +#endif + if (r != Result.OK) + { + throw SQLiteException.New(r, GetErrmsg(db)); + } + return stmt; + } + + [DllImport(LibraryPath, EntryPoint = "sqlite3_step", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Step(IntPtr stmt); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_reset", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Reset(IntPtr stmt); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_finalize", CallingConvention = CallingConvention.Cdecl)] + public static extern Result Finalize(IntPtr stmt); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_last_insert_rowid", CallingConvention = CallingConvention.Cdecl)] + public static extern long LastInsertRowid(IntPtr db); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_errmsg16", CallingConvention = CallingConvention.Cdecl)] + public static extern IntPtr Errmsg(IntPtr db); + + public static string GetErrmsg(IntPtr db) + { + return Marshal.PtrToStringUni(Errmsg(db)); + } + + [DllImport(LibraryPath, EntryPoint = "sqlite3_bind_parameter_index", CallingConvention = CallingConvention.Cdecl)] + public static extern int BindParameterIndex(IntPtr stmt, [MarshalAs(UnmanagedType.LPStr)] string name); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_bind_null", CallingConvention = CallingConvention.Cdecl)] + public static extern int BindNull(IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_bind_int", CallingConvention = CallingConvention.Cdecl)] + public static extern int BindInt(IntPtr stmt, int index, int val); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_bind_int64", CallingConvention = CallingConvention.Cdecl)] + public static extern int BindInt64(IntPtr stmt, int index, long val); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_bind_double", CallingConvention = CallingConvention.Cdecl)] + public static extern int BindDouble(IntPtr stmt, int index, double val); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_bind_text16", CallingConvention = CallingConvention.Cdecl, CharSet = CharSet.Unicode)] + public static extern int BindText(IntPtr stmt, int index, [MarshalAs(UnmanagedType.LPWStr)] string val, int n, IntPtr free); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_bind_blob", CallingConvention = CallingConvention.Cdecl)] + public static extern int BindBlob(IntPtr stmt, int index, byte[] val, int n, IntPtr free); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_count", CallingConvention = CallingConvention.Cdecl)] + public static extern int ColumnCount(IntPtr stmt); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_name", CallingConvention = CallingConvention.Cdecl)] + public static extern IntPtr ColumnName(IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_name16", CallingConvention = CallingConvention.Cdecl)] + static extern IntPtr ColumnName16Internal(IntPtr stmt, int index); + public static string ColumnName16(IntPtr stmt, int index) + { + return Marshal.PtrToStringUni(ColumnName16Internal(stmt, index)); + } + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_type", CallingConvention = CallingConvention.Cdecl)] + public static extern ColType ColumnType(IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_int", CallingConvention = CallingConvention.Cdecl)] + public static extern int ColumnInt(IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_int64", CallingConvention = CallingConvention.Cdecl)] + public static extern long ColumnInt64(IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_double", CallingConvention = CallingConvention.Cdecl)] + public static extern double ColumnDouble(IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_text", CallingConvention = CallingConvention.Cdecl)] + public static extern IntPtr ColumnText(IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_text16", CallingConvention = CallingConvention.Cdecl)] + public static extern IntPtr ColumnText16(IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_blob", CallingConvention = CallingConvention.Cdecl)] + public static extern IntPtr ColumnBlob(IntPtr stmt, int index); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_column_bytes", CallingConvention = CallingConvention.Cdecl)] + public static extern int ColumnBytes(IntPtr stmt, int index); + + public static string ColumnString(IntPtr stmt, int index) + { + return Marshal.PtrToStringUni(SQLite3.ColumnText16(stmt, index)); + } + + public static byte[] ColumnByteArray(IntPtr stmt, int index) + { + int length = ColumnBytes(stmt, index); + var result = new byte[length]; + if (length > 0) + Marshal.Copy(ColumnBlob(stmt, index), result, 0, length); + return result; + } + + [DllImport(LibraryPath, EntryPoint = "sqlite3_extended_errcode", CallingConvention = CallingConvention.Cdecl)] + public static extern ExtendedResult ExtendedErrCode(IntPtr db); + + [DllImport(LibraryPath, EntryPoint = "sqlite3_libversion_number", CallingConvention = CallingConvention.Cdecl)] + public static extern int LibVersionNumber(); +#else + public static Result Open (string filename, out Sqlite3DatabaseHandle db) + { + return (Result)Sqlite3.sqlite3_open (filename, out db); + } + + public static Result Open (string filename, out Sqlite3DatabaseHandle db, int flags, IntPtr zVfs) + { +#if USE_WP8_NATIVE_SQLITE + return (Result)Sqlite3.sqlite3_open_v2(filename, out db, flags, ""); +#else + return (Result)Sqlite3.sqlite3_open_v2 (filename, out db, flags, null); +#endif + } + + public static Result Close (Sqlite3DatabaseHandle db) + { + return (Result)Sqlite3.sqlite3_close (db); + } + + public static Result Close2 (Sqlite3DatabaseHandle db) + { + return (Result)Sqlite3.sqlite3_close_v2 (db); + } + + public static Result BusyTimeout (Sqlite3DatabaseHandle db, int milliseconds) + { + return (Result)Sqlite3.sqlite3_busy_timeout (db, milliseconds); + } + + public static int Changes (Sqlite3DatabaseHandle db) + { + return Sqlite3.sqlite3_changes (db); + } + + public static Sqlite3Statement Prepare2 (Sqlite3DatabaseHandle db, string query) + { + Sqlite3Statement stmt = default (Sqlite3Statement); +#if USE_WP8_NATIVE_SQLITE || USE_SQLITEPCL_RAW + var r = Sqlite3.sqlite3_prepare_v2 (db, query, out stmt); +#else + stmt = new Sqlite3Statement(); + var r = Sqlite3.sqlite3_prepare_v2(db, query, -1, ref stmt, 0); +#endif + if (r != 0) { + throw SQLiteException.New ((Result)r, GetErrmsg (db)); + } + return stmt; + } + + public static Result Step (Sqlite3Statement stmt) + { + return (Result)Sqlite3.sqlite3_step (stmt); + } + + public static Result Reset (Sqlite3Statement stmt) + { + return (Result)Sqlite3.sqlite3_reset (stmt); + } + + public static Result Finalize (Sqlite3Statement stmt) + { + return (Result)Sqlite3.sqlite3_finalize (stmt); + } + + public static long LastInsertRowid (Sqlite3DatabaseHandle db) + { + return Sqlite3.sqlite3_last_insert_rowid (db); + } + + public static string GetErrmsg (Sqlite3DatabaseHandle db) + { + return Sqlite3.sqlite3_errmsg (db); + } + + public static int BindParameterIndex (Sqlite3Statement stmt, string name) + { + return Sqlite3.sqlite3_bind_parameter_index (stmt, name); + } + + public static int BindNull (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_bind_null (stmt, index); + } + + public static int BindInt (Sqlite3Statement stmt, int index, int val) + { + return Sqlite3.sqlite3_bind_int (stmt, index, val); + } + + public static int BindInt64 (Sqlite3Statement stmt, int index, long val) + { + return Sqlite3.sqlite3_bind_int64 (stmt, index, val); + } + + public static int BindDouble (Sqlite3Statement stmt, int index, double val) + { + return Sqlite3.sqlite3_bind_double (stmt, index, val); + } + + public static int BindText (Sqlite3Statement stmt, int index, string val, int n, IntPtr free) + { +#if USE_WP8_NATIVE_SQLITE + return Sqlite3.sqlite3_bind_text(stmt, index, val, n); +#elif USE_SQLITEPCL_RAW + return Sqlite3.sqlite3_bind_text (stmt, index, val); +#else + return Sqlite3.sqlite3_bind_text(stmt, index, val, n, null); +#endif + } + + public static int BindBlob (Sqlite3Statement stmt, int index, byte[] val, int n, IntPtr free) + { +#if USE_WP8_NATIVE_SQLITE + return Sqlite3.sqlite3_bind_blob(stmt, index, val, n); +#elif USE_SQLITEPCL_RAW + return Sqlite3.sqlite3_bind_blob (stmt, index, val); +#else + return Sqlite3.sqlite3_bind_blob(stmt, index, val, n, null); +#endif + } + + public static int ColumnCount (Sqlite3Statement stmt) + { + return Sqlite3.sqlite3_column_count (stmt); + } + + public static string ColumnName (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_name (stmt, index); + } + + public static string ColumnName16 (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_name (stmt, index); + } + + public static ColType ColumnType (Sqlite3Statement stmt, int index) + { + return (ColType)Sqlite3.sqlite3_column_type (stmt, index); + } + + public static int ColumnInt (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_int (stmt, index); + } + + public static long ColumnInt64 (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_int64 (stmt, index); + } + + public static double ColumnDouble (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_double (stmt, index); + } + + public static string ColumnText (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_text (stmt, index); + } + + public static string ColumnText16 (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_text (stmt, index); + } + + public static byte[] ColumnBlob (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_blob (stmt, index); + } + + public static int ColumnBytes (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_bytes (stmt, index); + } + + public static string ColumnString (Sqlite3Statement stmt, int index) + { + return Sqlite3.sqlite3_column_text (stmt, index); + } + + public static byte[] ColumnByteArray (Sqlite3Statement stmt, int index) + { + int length = ColumnBytes (stmt, index); + if (length > 0) { + return ColumnBlob (stmt, index); + } + return new byte[0]; + } + + public static Result EnableLoadExtension (Sqlite3DatabaseHandle db, int onoff) + { + return (Result)Sqlite3.sqlite3_enable_load_extension (db, onoff); + } + + public static int LibVersionNumber () + { + return Sqlite3.sqlite3_libversion_number (); + } + + public static ExtendedResult ExtendedErrCode (Sqlite3DatabaseHandle db) + { + return (ExtendedResult)Sqlite3.sqlite3_extended_errcode (db); + } +#endif + + public enum ColType : int + { + Integer = 1, + Float = 2, + Text = 3, + Blob = 4, + Null = 5 + } + } +} \ No newline at end of file diff --git a/src/GitHub.App/sqlite-net/SQLiteAsync.cs b/src/GitHub.App/sqlite-net/SQLiteAsync.cs new file mode 100644 index 0000000000..77d4fb0388 --- /dev/null +++ b/src/GitHub.App/sqlite-net/SQLiteAsync.cs @@ -0,0 +1,1447 @@ +// +// Copyright (c) 2012-2017 Krueger Systems, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. +// + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Threading; +using System.Threading.Tasks; + +namespace SQLite +{ + /// + /// A pooled asynchronous connection to a SQLite database. + /// + public partial class SQLiteAsyncConnection + { + SQLiteConnectionString _connectionString; + SQLiteConnectionWithLock _fullMutexReadConnection; + readonly bool isFullMutex; + SQLiteOpenFlags _openFlags; + + /// + /// Constructs a new SQLiteAsyncConnection and opens a pooled SQLite database specified by databasePath. + /// + /// + /// Specifies the path to the database file. + /// + /// + /// Specifies whether to store DateTime properties as ticks (true) or strings (false). You + /// absolutely do want to store them as Ticks in all new projects. The value of false is + /// only here for backwards compatibility. There is a *significant* speed advantage, with no + /// down sides, when setting storeDateTimeAsTicks = true. + /// If you use DateTimeOffset properties, it will be always stored as ticks regardingless + /// the storeDateTimeAsTicks parameter. + /// + /// + /// Specifies the encryption key to use on the database. Should be a string or a byte[]. + /// + public SQLiteAsyncConnection(string databasePath, bool storeDateTimeAsTicks = true, object key = null) + : this(databasePath, SQLiteOpenFlags.FullMutex | SQLiteOpenFlags.ReadWrite | SQLiteOpenFlags.Create, storeDateTimeAsTicks, key: key) + { + } + + /// + /// Constructs a new SQLiteAsyncConnection and opens a pooled SQLite database specified by databasePath. + /// + /// + /// Specifies the path to the database file. + /// + /// + /// Flags controlling how the connection should be opened. + /// + /// + /// Specifies whether to store DateTime properties as ticks (true) or strings (false). You + /// absolutely do want to store them as Ticks in all new projects. The value of false is + /// only here for backwards compatibility. There is a *significant* speed advantage, with no + /// down sides, when setting storeDateTimeAsTicks = true. + /// If you use DateTimeOffset properties, it will be always stored as ticks regardingless + /// the storeDateTimeAsTicks parameter. + /// + /// + /// Specifies the encryption key to use on the database. Should be a string or a byte[]. + /// + public SQLiteAsyncConnection(string databasePath, SQLiteOpenFlags openFlags, bool storeDateTimeAsTicks = true, object key = null) + { + _openFlags = openFlags; + isFullMutex = _openFlags.HasFlag(SQLiteOpenFlags.FullMutex); + _connectionString = new SQLiteConnectionString(databasePath, storeDateTimeAsTicks, key); + if (isFullMutex) + _fullMutexReadConnection = new SQLiteConnectionWithLock(_connectionString, openFlags) { SkipLock = true }; + } + + /// + /// Gets the database path used by this connection. + /// + public string DatabasePath => GetConnection().DatabasePath; + + /// + /// Gets the SQLite library version number. 3007014 would be v3.7.14 + /// + public int LibVersionNumber => GetConnection().LibVersionNumber; + + /// + /// The amount of time to wait for a table to become unlocked. + /// + public TimeSpan GetBusyTimeout() + { + return GetConnection().BusyTimeout; + } + + /// + /// Sets the amount of time to wait for a table to become unlocked. + /// + public Task SetBusyTimeoutAsync(TimeSpan value) + { + return ReadAsync(conn => { + conn.BusyTimeout = value; + return null; + }); + } + + /// + /// Whether to store DateTime properties as ticks (true) or strings (false). + /// + public bool StoreDateTimeAsTicks => GetConnection().StoreDateTimeAsTicks; + + /// + /// Whether to writer queries to during execution. + /// + /// The tracer. + public bool Trace + { + get { return GetConnection().Trace; } + set { GetConnection().Trace = value; } + } + + /// + /// The delegate responsible for writing trace lines. + /// + /// The tracer. + public Action Tracer + { + get { return GetConnection().Tracer; } + set { GetConnection().Tracer = value; } + } + + /// + /// Whether Trace lines should be written that show the execution time of queries. + /// + public bool TimeExecution + { + get { return GetConnection().TimeExecution; } + set { GetConnection().TimeExecution = value; } + } + + /// + /// Returns the mappings from types to tables that the connection + /// currently understands. + /// + public IEnumerable TableMappings => GetConnection().TableMappings; + + /// + /// Closes all connections to all async databases. + /// You should *never* need to do this. + /// This is a blocking operation that will return when all connections + /// have been closed. + /// + public static void ResetPool() + { + SQLiteConnectionPool.Shared.Reset(); + } + + /// + /// Gets the pooled lockable connection used by this async connection. + /// You should never need to use this. This is provided only to add additional + /// functionality to SQLite-net. If you use this connection, you must use + /// the Lock method on it while using it. + /// + public SQLiteConnectionWithLock GetConnection() + { + return SQLiteConnectionPool.Shared.GetConnection(_connectionString, _openFlags); + } + + /// + /// Closes any pooled connections used by the database. + /// + public Task CloseAsync() + { + return Task.Factory.StartNew(() => { + SQLiteConnectionPool.Shared.CloseConnection(_connectionString, _openFlags); + }, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default); + } + + Task ReadAsync(Func read) + { + return Task.Factory.StartNew(() => { + var conn = isFullMutex ? _fullMutexReadConnection : GetConnection(); + using (conn.Lock()) + { + return read(conn); + } + }, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default); + } + + Task WriteAsync(Func write) + { + return Task.Factory.StartNew(() => { + var conn = GetConnection(); + using (conn.Lock()) + { + return write(conn); + } + }, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default); + } + + /// + /// Enable or disable extension loading. + /// + public Task EnableLoadExtensionAsync(bool enabled) + { + return WriteAsync(conn => { + conn.EnableLoadExtension(enabled); + return null; + }); + } + + /// + /// Executes a "create table if not exists" on the database. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated. + /// + public Task CreateTableAsync(CreateFlags createFlags = CreateFlags.None) + where T : new() + { + return WriteAsync(conn => conn.CreateTable(createFlags)); + } + + /// + /// Executes a "create table if not exists" on the database. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// Type to reflect to a database table. + /// Optional flags allowing implicit PK and indexes based on naming conventions. + /// + /// Whether the table was created or migrated. + /// + public Task CreateTableAsync(Type ty, CreateFlags createFlags = CreateFlags.None) + { + return WriteAsync(conn => conn.CreateTable(ty, createFlags)); + } + + /// + /// Executes a "create table if not exists" on the database for each type. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated for each type. + /// + public Task CreateTablesAsync(CreateFlags createFlags = CreateFlags.None) + where T : new() + where T2 : new() + { + return CreateTablesAsync(createFlags, typeof(T), typeof(T2)); + } + + /// + /// Executes a "create table if not exists" on the database for each type. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated for each type. + /// + public Task CreateTablesAsync(CreateFlags createFlags = CreateFlags.None) + where T : new() + where T2 : new() + where T3 : new() + { + return CreateTablesAsync(createFlags, typeof(T), typeof(T2), typeof(T3)); + } + + /// + /// Executes a "create table if not exists" on the database for each type. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated for each type. + /// + public Task CreateTablesAsync(CreateFlags createFlags = CreateFlags.None) + where T : new() + where T2 : new() + where T3 : new() + where T4 : new() + { + return CreateTablesAsync(createFlags, typeof(T), typeof(T2), typeof(T3), typeof(T4)); + } + + /// + /// Executes a "create table if not exists" on the database for each type. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated for each type. + /// + public Task CreateTablesAsync(CreateFlags createFlags = CreateFlags.None) + where T : new() + where T2 : new() + where T3 : new() + where T4 : new() + where T5 : new() + { + return CreateTablesAsync(createFlags, typeof(T), typeof(T2), typeof(T3), typeof(T4), typeof(T5)); + } + + /// + /// Executes a "create table if not exists" on the database for each type. It also + /// creates any specified indexes on the columns of the table. It uses + /// a schema automatically generated from the specified type. You can + /// later access this schema by calling GetMapping. + /// + /// + /// Whether the table was created or migrated for each type. + /// + public Task CreateTablesAsync(CreateFlags createFlags = CreateFlags.None, params Type[] types) + { + return WriteAsync(conn => conn.CreateTables(createFlags, types)); + } + + /// + /// Executes a "drop table" on the database. This is non-recoverable. + /// + public Task DropTableAsync() + where T : new() + { + return WriteAsync(conn => conn.DropTable()); + } + + /// + /// Executes a "drop table" on the database. This is non-recoverable. + /// + /// + /// The TableMapping used to identify the table. + /// + public Task DropTableAsync(TableMapping map) + { + return WriteAsync(conn => conn.DropTable(map)); + } + + /// + /// Creates an index for the specified table and column. + /// + /// Name of the database table + /// Name of the column to index + /// Whether the index should be unique + public Task CreateIndexAsync(string tableName, string columnName, bool unique = false) + { + return WriteAsync(conn => conn.CreateIndex(tableName, columnName, unique)); + } + + /// + /// Creates an index for the specified table and column. + /// + /// Name of the index to create + /// Name of the database table + /// Name of the column to index + /// Whether the index should be unique + public Task CreateIndexAsync(string indexName, string tableName, string columnName, bool unique = false) + { + return WriteAsync(conn => conn.CreateIndex(indexName, tableName, columnName, unique)); + } + + /// + /// Creates an index for the specified table and columns. + /// + /// Name of the database table + /// An array of column names to index + /// Whether the index should be unique + public Task CreateIndexAsync(string tableName, string[] columnNames, bool unique = false) + { + return WriteAsync(conn => conn.CreateIndex(tableName, columnNames, unique)); + } + + /// + /// Creates an index for the specified table and columns. + /// + /// Name of the index to create + /// Name of the database table + /// An array of column names to index + /// Whether the index should be unique + public Task CreateIndexAsync(string indexName, string tableName, string[] columnNames, bool unique = false) + { + return WriteAsync(conn => conn.CreateIndex(indexName, tableName, columnNames, unique)); + } + + /// + /// Creates an index for the specified object property. + /// e.g. CreateIndex<Client>(c => c.Name); + /// + /// Type to reflect to a database table. + /// Property to index + /// Whether the index should be unique + public Task CreateIndexAsync(Expression> property, bool unique = false) + { + return WriteAsync(conn => conn.CreateIndex(property, unique)); + } + + /// + /// Inserts the given object and retrieves its + /// auto incremented primary key if it has one. + /// + /// + /// The object to insert. + /// + /// + /// The number of rows added to the table. + /// + public Task InsertAsync(object obj) + { + return WriteAsync(conn => conn.Insert(obj)); + } + + /// + /// Inserts the given object (and updates its + /// auto incremented primary key if it has one). + /// The return value is the number of rows added to the table. + /// + /// + /// The object to insert. + /// + /// + /// The type of object to insert. + /// + /// + /// The number of rows added to the table. + /// + public Task InsertAsync(object obj, Type objType) + { + return WriteAsync(conn => conn.Insert(obj, objType)); + } + + /// + /// Inserts the given object (and updates its + /// auto incremented primary key if it has one). + /// The return value is the number of rows added to the table. + /// + /// + /// The object to insert. + /// + /// + /// Literal SQL code that gets placed into the command. INSERT {extra} INTO ... + /// + /// + /// The number of rows added to the table. + /// + public Task InsertAsync(object obj, string extra) + { + return WriteAsync(conn => conn.Insert(obj, extra)); + } + + /// + /// Inserts the given object (and updates its + /// auto incremented primary key if it has one). + /// The return value is the number of rows added to the table. + /// + /// + /// The object to insert. + /// + /// + /// Literal SQL code that gets placed into the command. INSERT {extra} INTO ... + /// + /// + /// The type of object to insert. + /// + /// + /// The number of rows added to the table. + /// + public Task InsertAsync(object obj, string extra, Type objType) + { + return WriteAsync(conn => conn.Insert(obj, extra, objType)); + } + + /// + /// Inserts the given object (and updates its + /// auto incremented primary key if it has one). + /// The return value is the number of rows added to the table. + /// If a UNIQUE constraint violation occurs with + /// some pre-existing object, this function deletes + /// the old object. + /// + /// + /// The object to insert. + /// + /// + /// The number of rows modified. + /// + public Task InsertOrReplaceAsync(object obj) + { + return WriteAsync(conn => conn.InsertOrReplace(obj)); + } + + /// + /// Inserts the given object (and updates its + /// auto incremented primary key if it has one). + /// The return value is the number of rows added to the table. + /// If a UNIQUE constraint violation occurs with + /// some pre-existing object, this function deletes + /// the old object. + /// + /// + /// The object to insert. + /// + /// + /// The type of object to insert. + /// + /// + /// The number of rows modified. + /// + public Task InsertOrReplaceAsync(object obj, Type objType) + { + return WriteAsync(conn => conn.InsertOrReplace(obj, objType)); + } + + /// + /// Updates all of the columns of a table using the specified object + /// except for its primary key. + /// The object is required to have a primary key. + /// + /// + /// The object to update. It must have a primary key designated using the PrimaryKeyAttribute. + /// + /// + /// The number of rows updated. + /// + public Task UpdateAsync(object obj) + { + return WriteAsync(conn => conn.Update(obj)); + } + + /// + /// Updates all of the columns of a table using the specified object + /// except for its primary key. + /// The object is required to have a primary key. + /// + /// + /// The object to update. It must have a primary key designated using the PrimaryKeyAttribute. + /// + /// + /// The type of object to insert. + /// + /// + /// The number of rows updated. + /// + public Task UpdateAsync(object obj, Type objType) + { + return WriteAsync(conn => conn.Update(obj, objType)); + } + + /// + /// Updates all specified objects. + /// + /// + /// An of the objects to insert. + /// + /// + /// A boolean indicating if the inserts should be wrapped in a transaction + /// + /// + /// The number of rows modified. + /// + public Task UpdateAllAsync(IEnumerable objects, bool runInTransaction = true) + { + return WriteAsync(conn => conn.UpdateAll(objects, runInTransaction)); + } + + /// + /// Deletes the given object from the database using its primary key. + /// + /// + /// The object to delete. It must have a primary key designated using the PrimaryKeyAttribute. + /// + /// + /// The number of rows deleted. + /// + public Task DeleteAsync(object objectToDelete) + { + return WriteAsync(conn => conn.Delete(objectToDelete)); + } + + /// + /// Deletes the object with the specified primary key. + /// + /// + /// The primary key of the object to delete. + /// + /// + /// The number of objects deleted. + /// + /// + /// The type of object. + /// + public Task DeleteAsync(object primaryKey) + { + return WriteAsync(conn => conn.Delete(primaryKey)); + } + + /// + /// Deletes the object with the specified primary key. + /// + /// + /// The primary key of the object to delete. + /// + /// + /// The TableMapping used to identify the table. + /// + /// + /// The number of objects deleted. + /// + public Task DeleteAsync(object primaryKey, TableMapping map) + { + return WriteAsync(conn => conn.Delete(primaryKey, map)); + } + + /// + /// Deletes all the objects from the specified table. + /// WARNING WARNING: Let me repeat. It deletes ALL the objects from the + /// specified table. Do you really want to do that? + /// + /// + /// The number of objects deleted. + /// + /// + /// The type of objects to delete. + /// + public Task DeleteAllAsync() + { + return WriteAsync(conn => conn.DeleteAll()); + } + + /// + /// Deletes all the objects from the specified table. + /// WARNING WARNING: Let me repeat. It deletes ALL the objects from the + /// specified table. Do you really want to do that? + /// + /// + /// The TableMapping used to identify the table. + /// + /// + /// The number of objects deleted. + /// + public Task DeleteAllAsync(TableMapping map) + { + return WriteAsync(conn => conn.DeleteAll(map)); + } + + /// + /// Attempts to retrieve an object with the given primary key from the table + /// associated with the specified type. Use of this method requires that + /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). + /// + /// + /// The primary key. + /// + /// + /// The object with the given primary key. Throws a not found exception + /// if the object is not found. + /// + public Task GetAsync(object pk) + where T : new() + { + return ReadAsync(conn => conn.Get(pk)); + } + + /// + /// Attempts to retrieve an object with the given primary key from the table + /// associated with the specified type. Use of this method requires that + /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). + /// + /// + /// The primary key. + /// + /// + /// The TableMapping used to identify the table. + /// + /// + /// The object with the given primary key. Throws a not found exception + /// if the object is not found. + /// + public Task GetAsync(object pk, TableMapping map) + { + return ReadAsync(conn => conn.Get(pk, map)); + } + + /// + /// Attempts to retrieve the first object that matches the predicate from the table + /// associated with the specified type. + /// + /// + /// A predicate for which object to find. + /// + /// + /// The object that matches the given predicate. Throws a not found exception + /// if the object is not found. + /// + public Task GetAsync(Expression> predicate) + where T : new() + { + return ReadAsync(conn => conn.Get(predicate)); + } + + /// + /// Attempts to retrieve an object with the given primary key from the table + /// associated with the specified type. Use of this method requires that + /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). + /// + /// + /// The primary key. + /// + /// + /// The object with the given primary key or null + /// if the object is not found. + /// + public Task FindAsync(object pk) + where T : new() + { + return ReadAsync(conn => conn.Find(pk)); + } + + /// + /// Attempts to retrieve an object with the given primary key from the table + /// associated with the specified type. Use of this method requires that + /// the given type have a designated PrimaryKey (using the PrimaryKeyAttribute). + /// + /// + /// The primary key. + /// + /// + /// The TableMapping used to identify the table. + /// + /// + /// The object with the given primary key or null + /// if the object is not found. + /// + public Task FindAsync(object pk, TableMapping map) + { + return ReadAsync(conn => conn.Find(pk, map)); + } + + /// + /// Attempts to retrieve the first object that matches the predicate from the table + /// associated with the specified type. + /// + /// + /// A predicate for which object to find. + /// + /// + /// The object that matches the given predicate or null + /// if the object is not found. + /// + public Task FindAsync(Expression> predicate) + where T : new() + { + return ReadAsync(conn => conn.Find(predicate)); + } + + /// + /// Attempts to retrieve the first object that matches the query from the table + /// associated with the specified type. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// The object that matches the given predicate or null + /// if the object is not found. + /// + public Task FindWithQueryAsync(string query, params object[] args) + where T : new() + { + return ReadAsync(conn => conn.FindWithQuery(query, args)); + } + + /// + /// Attempts to retrieve the first object that matches the query from the table + /// associated with the specified type. + /// + /// + /// The TableMapping used to identify the table. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// The object that matches the given predicate or null + /// if the object is not found. + /// + public Task FindWithQueryAsync(TableMapping map, string query, params object[] args) + { + return ReadAsync(conn => conn.FindWithQuery(map, query, args)); + } + + /// + /// Retrieves the mapping that is automatically generated for the given type. + /// + /// + /// The type whose mapping to the database is returned. + /// + /// + /// Optional flags allowing implicit PK and indexes based on naming conventions + /// + /// + /// The mapping represents the schema of the columns of the database and contains + /// methods to set and get properties of objects. + /// + public Task GetMappingAsync(Type type, CreateFlags createFlags = CreateFlags.None) + { + return ReadAsync(conn => conn.GetMapping(type, createFlags)); + } + + /// + /// Retrieves the mapping that is automatically generated for the given type. + /// + /// + /// Optional flags allowing implicit PK and indexes based on naming conventions + /// + /// + /// The mapping represents the schema of the columns of the database and contains + /// methods to set and get properties of objects. + /// + public Task GetMappingAsync(CreateFlags createFlags = CreateFlags.None) + where T : new() + { + return ReadAsync(conn => conn.GetMapping(createFlags)); + } + + /// + /// Query the built-in sqlite table_info table for a specific tables columns. + /// + /// The columns contains in the table. + /// Table name. + public Task> GetTableInfoAsync(string tableName) + { + return ReadAsync(conn => conn.GetTableInfo(tableName)); + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// Use this method instead of Query when you don't expect rows back. Such cases include + /// INSERTs, UPDATEs, and DELETEs. + /// You can set the Trace or TimeExecution properties of the connection + /// to profile execution. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// The number of rows modified in the database as a result of this execution. + /// + public Task ExecuteAsync(string query, params object[] args) + { + return WriteAsync(conn => conn.Execute(query, args)); + } + + /// + /// Inserts all specified objects. + /// + /// + /// An of the objects to insert. + /// + /// A boolean indicating if the inserts should be wrapped in a transaction. + /// + /// + /// The number of rows added to the table. + /// + public Task InsertAllAsync(IEnumerable objects, bool runInTransaction = true) + { + return WriteAsync(conn => conn.InsertAll(objects, runInTransaction)); + } + + /// + /// Inserts all specified objects. + /// + /// + /// An of the objects to insert. + /// + /// + /// Literal SQL code that gets placed into the command. INSERT {extra} INTO ... + /// + /// + /// A boolean indicating if the inserts should be wrapped in a transaction. + /// + /// + /// The number of rows added to the table. + /// + public Task InsertAllAsync(IEnumerable objects, string extra, bool runInTransaction = true) + { + return WriteAsync(conn => conn.InsertAll(objects, extra, runInTransaction)); + } + + /// + /// Inserts all specified objects. + /// + /// + /// An of the objects to insert. + /// + /// + /// The type of object to insert. + /// + /// + /// A boolean indicating if the inserts should be wrapped in a transaction. + /// + /// + /// The number of rows added to the table. + /// + public Task InsertAllAsync(IEnumerable objects, Type objType, bool runInTransaction = true) + { + return WriteAsync(conn => conn.InsertAll(objects, objType, runInTransaction)); + } + + /// + /// Executes within a (possibly nested) transaction by wrapping it in a SAVEPOINT. If an + /// exception occurs the whole transaction is rolled back, not just the current savepoint. The exception + /// is rethrown. + /// + /// + /// The to perform within a transaction. can contain any number + /// of operations on the connection but should never call or + /// . + /// + public Task RunInTransactionAsync(Action action) + { + return WriteAsync(conn => { + conn.BeginTransaction(); + try + { + action(conn); + conn.Commit(); + return null; + } + catch (Exception) + { + conn.Rollback(); + throw; + } + }); + } + + /// + /// Returns a queryable interface to the table represented by the given type. + /// + /// + /// A queryable object that is able to translate Where, OrderBy, and Take + /// queries into native SQL. + /// + public AsyncTableQuery Table() + where T : new() + { + // + // This isn't async as the underlying connection doesn't go out to the database + // until the query is performed. The Async methods are on the query iteself. + // + var conn = GetConnection(); + return new AsyncTableQuery(conn.Table()); + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// Use this method when return primitive values. + /// You can set the Trace or TimeExecution properties of the connection + /// to profile execution. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// The number of rows modified in the database as a result of this execution. + /// + public Task ExecuteScalarAsync(string query, params object[] args) + { + return WriteAsync(conn => { + var command = conn.CreateCommand(query, args); + return command.ExecuteScalar(); + }); + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// It returns each row of the result using the mapping automatically generated for + /// the given type. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for each row returned by the query. + /// + public Task> QueryAsync(string query, params object[] args) + where T : new() + { + return ReadAsync(conn => conn.Query(query, args)); + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// It returns each row of the result using the specified mapping. This function is + /// only used by libraries in order to query the database via introspection. It is + /// normally not used. + /// + /// + /// A to use to convert the resulting rows + /// into objects. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for each row returned by the query. + /// + public Task> QueryAsync(TableMapping map, string query, params object[] args) + { + return ReadAsync(conn => conn.Query(map, query, args)); + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// It returns each row of the result using the mapping automatically generated for + /// the given type. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for each row returned by the query. + /// The enumerator will call sqlite3_step on each call to MoveNext, so the database + /// connection must remain open for the lifetime of the enumerator. + /// + public Task> DeferredQueryAsync(string query, params object[] args) + where T : new() + { + return ReadAsync(conn => (IEnumerable)conn.DeferredQuery(query, args).ToList()); + } + + /// + /// Creates a SQLiteCommand given the command text (SQL) with arguments. Place a '?' + /// in the command text for each of the arguments and then executes that command. + /// It returns each row of the result using the specified mapping. This function is + /// only used by libraries in order to query the database via introspection. It is + /// normally not used. + /// + /// + /// A to use to convert the resulting rows + /// into objects. + /// + /// + /// The fully escaped SQL. + /// + /// + /// Arguments to substitute for the occurences of '?' in the query. + /// + /// + /// An enumerable with one result for each row returned by the query. + /// The enumerator will call sqlite3_step on each call to MoveNext, so the database + /// connection must remain open for the lifetime of the enumerator. + /// + public Task> DeferredQueryAsync(TableMapping map, string query, params object[] args) + { + return ReadAsync(conn => (IEnumerable)conn.DeferredQuery(map, query, args).ToList()); + } + } + + // + // TODO: Bind to AsyncConnection.GetConnection instead so that delayed + // execution can still work after a Pool.Reset. + // + + /// + /// Query to an asynchronous database connection. + /// + public class AsyncTableQuery + where T : new() + { + TableQuery _innerQuery; + + /// + /// Creates a new async query that uses given the synchronous query. + /// + public AsyncTableQuery(TableQuery innerQuery) + { + _innerQuery = innerQuery; + } + + Task ReadAsync(Func read) + { + return Task.Factory.StartNew(() => { + var conn = (SQLiteConnectionWithLock)_innerQuery.Connection; + using (conn.Lock()) + { + return read(conn); + } + }, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default); + } + + Task WriteAsync(Func write) + { + return Task.Factory.StartNew(() => { + var conn = (SQLiteConnectionWithLock)_innerQuery.Connection; + using (conn.Lock()) + { + return write(conn); + } + }, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default); + } + + /// + /// Filters the query based on a predicate. + /// + public AsyncTableQuery Where(Expression> predExpr) + { + return new AsyncTableQuery(_innerQuery.Where(predExpr)); + } + + /// + /// Skips a given number of elements from the query and then yields the remainder. + /// + public AsyncTableQuery Skip(int n) + { + return new AsyncTableQuery(_innerQuery.Skip(n)); + } + + /// + /// Yields a given number of elements from the query and then skips the remainder. + /// + public AsyncTableQuery Take(int n) + { + return new AsyncTableQuery(_innerQuery.Take(n)); + } + + /// + /// Order the query results according to a key. + /// + public AsyncTableQuery OrderBy(Expression> orderExpr) + { + return new AsyncTableQuery(_innerQuery.OrderBy(orderExpr)); + } + + /// + /// Order the query results according to a key. + /// + public AsyncTableQuery OrderByDescending(Expression> orderExpr) + { + return new AsyncTableQuery(_innerQuery.OrderByDescending(orderExpr)); + } + + /// + /// Order the query results according to a key. + /// + public AsyncTableQuery ThenBy(Expression> orderExpr) + { + return new AsyncTableQuery(_innerQuery.ThenBy(orderExpr)); + } + + /// + /// Order the query results according to a key. + /// + public AsyncTableQuery ThenByDescending(Expression> orderExpr) + { + return new AsyncTableQuery(_innerQuery.ThenByDescending(orderExpr)); + } + + /// + /// Queries the database and returns the results as a List. + /// + public Task> ToListAsync() + { + return ReadAsync(conn => _innerQuery.ToList()); + } + + /// + /// Queries the database and returns the results as an array. + /// + public Task ToArrayAsync() + { + return ReadAsync(conn => _innerQuery.ToArray()); + } + + /// + /// Execute SELECT COUNT(*) on the query + /// + public Task CountAsync() + { + return ReadAsync(conn => _innerQuery.Count()); + } + + /// + /// Execute SELECT COUNT(*) on the query with an additional WHERE clause. + /// + public Task CountAsync(Expression> predExpr) + { + return ReadAsync(conn => _innerQuery.Count(predExpr)); + } + + /// + /// Returns the element at a given index + /// + public Task ElementAtAsync(int index) + { + return ReadAsync(conn => _innerQuery.ElementAt(index)); + } + + /// + /// Returns the first element of this query. + /// + public Task FirstAsync() + { + return ReadAsync(conn => _innerQuery.First()); + } + + /// + /// Returns the first element of this query, or null if no element is found. + /// + public Task FirstOrDefaultAsync() + { + return ReadAsync(conn => _innerQuery.FirstOrDefault()); + } + + /// + /// Returns the first element of this query that matches the predicate. + /// + public Task FirstAsync(Expression> predExpr) + { + return ReadAsync(conn => _innerQuery.First(predExpr)); + } + + /// + /// Returns the first element of this query that matches the predicate. + /// + public Task FirstOrDefaultAsync(Expression> predExpr) + { + return ReadAsync(conn => _innerQuery.FirstOrDefault(predExpr)); + } + + /// + /// Delete all the rows that match this query and the given predicate. + /// + public Task DeleteAsync(Expression> predExpr) + { + return WriteAsync(conn => _innerQuery.Delete(predExpr)); + } + + /// + /// Delete all the rows that match this query. + /// + public Task DeleteAsync() + { + return WriteAsync(conn => _innerQuery.Delete()); + } + } + + class SQLiteConnectionPool + { + class Entry + { + public SQLiteConnectionString ConnectionString { get; private set; } + public SQLiteConnectionWithLock Connection { get; private set; } + + public Entry(SQLiteConnectionString connectionString, SQLiteOpenFlags openFlags) + { + ConnectionString = connectionString; + Connection = new SQLiteConnectionWithLock(connectionString, openFlags); + } + + public void Close() + { + if (Connection == null) + return; + using (var l = Connection.Lock()) + { + Connection.Dispose(); + } + Connection = null; + } + } + + readonly Dictionary _entries = new Dictionary(); + readonly object _entriesLock = new object(); + + static readonly SQLiteConnectionPool _shared = new SQLiteConnectionPool(); + + /// + /// Gets the singleton instance of the connection tool. + /// + public static SQLiteConnectionPool Shared + { + get + { + return _shared; + } + } + + public SQLiteConnectionWithLock GetConnection(SQLiteConnectionString connectionString, SQLiteOpenFlags openFlags) + { + lock (_entriesLock) + { + Entry entry; + string key = connectionString.ConnectionString; + + if (!_entries.TryGetValue(key, out entry)) + { + entry = new Entry(connectionString, openFlags); + _entries[key] = entry; + } + + return entry.Connection; + } + } + + public void CloseConnection(SQLiteConnectionString connectionString, SQLiteOpenFlags openFlags) + { + var key = connectionString.ConnectionString; + + Entry entry; + lock (_entriesLock) + { + if (_entries.TryGetValue(key, out entry)) + { + _entries.Remove(key); + } + } + + entry.Close(); + } + + /// + /// Closes all connections managed by this pool. + /// + public void Reset() + { + List entries; + lock (_entriesLock) + { + entries = new List(_entries.Values); + _entries.Clear(); + } + + foreach (var e in entries) + { + e.Close(); + } + } + } + + /// + /// This is a normal connection except it contains a Lock method that + /// can be used to serialize access to the database. + /// + public class SQLiteConnectionWithLock : SQLiteConnection + { + readonly object _lockPoint = new object(); + + /// + /// Initializes a new instance of the class. + /// + /// Connection string containing the DatabasePath. + /// Open flags. + public SQLiteConnectionWithLock(SQLiteConnectionString connectionString, SQLiteOpenFlags openFlags) + : base(connectionString.DatabasePath, openFlags, connectionString.StoreDateTimeAsTicks, key: connectionString.Key) + { + } + + /// + /// Gets or sets a value indicating whether this skip lock. + /// + /// true if skip lock; otherwise, false. + public bool SkipLock { get; set; } + + /// + /// Lock the database to serialize access to it. To unlock it, call Dispose + /// on the returned object. + /// + /// The lock. + public IDisposable Lock() + { + return SkipLock ? (IDisposable)new FakeLockWrapper() : new LockWrapper(_lockPoint); + } + + class LockWrapper : IDisposable + { + object _lockPoint; + + public LockWrapper(object lockPoint) + { + _lockPoint = lockPoint; + Monitor.Enter(_lockPoint); + } + + public void Dispose() + { + Monitor.Exit(_lockPoint); + } + } + class FakeLockWrapper : IDisposable + { + public void Dispose() + { + } + } + } +} diff --git a/src/GitHub.Exports.Reactive/GitHub.Exports.Reactive.csproj b/src/GitHub.Exports.Reactive/GitHub.Exports.Reactive.csproj index a07af76b63..acf872a864 100644 --- a/src/GitHub.Exports.Reactive/GitHub.Exports.Reactive.csproj +++ b/src/GitHub.Exports.Reactive/GitHub.Exports.Reactive.csproj @@ -27,5 +27,6 @@ + diff --git a/src/GitHub.InlineReviews/Services/IInlineCommentPeekService.cs b/src/GitHub.Exports.Reactive/Services/IInlineCommentPeekService.cs similarity index 69% rename from src/GitHub.InlineReviews/Services/IInlineCommentPeekService.cs rename to src/GitHub.Exports.Reactive/Services/IInlineCommentPeekService.cs index 5351d64a2d..219f657241 100644 --- a/src/GitHub.InlineReviews/Services/IInlineCommentPeekService.cs +++ b/src/GitHub.Exports.Reactive/Services/IInlineCommentPeekService.cs @@ -1,10 +1,10 @@ using System; -using GitHub.InlineReviews.Tags; +using GitHub.Models; using Microsoft.VisualStudio.Language.Intellisense; using Microsoft.VisualStudio.Text; using Microsoft.VisualStudio.Text.Editor; -namespace GitHub.InlineReviews.Services +namespace GitHub.Services { /// /// Shows inline comments in a peek view. @@ -29,17 +29,10 @@ public interface IInlineCommentPeekService void Hide(ITextView textView); /// - /// Shows the peek view for a . + /// Shows the peek view for on an . /// /// The text view. /// The tag. - ITrackingPoint Show(ITextView textView, ShowInlineCommentTag tag); - - /// - /// Shows the peek view for an . - /// - /// The text view. - /// The tag. - ITrackingPoint Show(ITextView textView, AddInlineCommentTag tag); + ITrackingPoint Show(ITextView textView, DiffSide side, int lineNumber); } } \ No newline at end of file diff --git a/src/GitHub.Exports.Reactive/ViewModels/IClosable.cs b/src/GitHub.Exports.Reactive/ViewModels/IClosable.cs new file mode 100644 index 0000000000..ac6ab171c3 --- /dev/null +++ b/src/GitHub.Exports.Reactive/ViewModels/IClosable.cs @@ -0,0 +1,16 @@ +using System; +using System.Reactive; + +namespace GitHub.ViewModels +{ + /// + /// Represents an entity that can be closed. + /// + public interface IClosable + { + /// + /// Gets an observable that is fired when the entity is closed. + /// + IObservable Closed { get; } + } +} diff --git a/src/GitHub.Exports.Reactive/ViewModels/ICommentThreadViewModel.cs b/src/GitHub.Exports.Reactive/ViewModels/ICommentThreadViewModel.cs index a902e9dbb8..31c24d15a5 100644 --- a/src/GitHub.Exports.Reactive/ViewModels/ICommentThreadViewModel.cs +++ b/src/GitHub.Exports.Reactive/ViewModels/ICommentThreadViewModel.cs @@ -22,16 +22,19 @@ public interface ICommentThreadViewModel : IViewModel /// /// Called by a comment in the thread to post itself as a new comment to the API. /// - Task PostComment(string body); + /// The comment to post. + Task PostComment(ICommentViewModel comment); /// /// Called by a comment in the thread to post itself as an edit to a comment to the API. /// - Task EditComment(string id, string body); + /// The comment to edit. + Task EditComment(ICommentViewModel comment); /// /// Called by a comment in the thread to delete the comment on the API. /// - Task DeleteComment(int pullRequestId, int commentId); + /// The comment to delete. + Task DeleteComment(ICommentViewModel comment); } } diff --git a/src/GitHub.Exports.Reactive/ViewModels/IPullRequestReviewCommentViewModel.cs b/src/GitHub.Exports.Reactive/ViewModels/IPullRequestReviewCommentViewModel.cs index 175c09f3f0..b05ab5c0f7 100644 --- a/src/GitHub.Exports.Reactive/ViewModels/IPullRequestReviewCommentViewModel.cs +++ b/src/GitHub.Exports.Reactive/ViewModels/IPullRequestReviewCommentViewModel.cs @@ -55,10 +55,12 @@ Task InitializeAsync( /// /// The pull request session. /// The thread that the comment is a part of. + /// Whether the comment thread is a pending review thread. /// Whether to start the placeholder in edit mode. Task InitializeAsPlaceholderAsync( IPullRequestSession session, ICommentThreadViewModel thread, + bool isPending, bool isEditing); } } \ No newline at end of file diff --git a/src/GitHub.Exports/Services/IMessageDraftStore.cs b/src/GitHub.Exports/Services/IMessageDraftStore.cs new file mode 100644 index 0000000000..a7ccf5233c --- /dev/null +++ b/src/GitHub.Exports/Services/IMessageDraftStore.cs @@ -0,0 +1,46 @@ +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace GitHub.Services +{ + /// + /// Represents a store in which drafts of messages can be held for later recall. + /// + public interface IMessageDraftStore + { + /// + /// Tries to get a draft. + /// + /// The type to deserialize. + /// The key. + /// The secondary key. + /// The draft data if it exists, otherwise null. + Task GetDraft(string key, string secondaryKey) where T : class; + + /// + /// Gets all drafts with the specified key. + /// + /// The type to deserialize. + /// The key. + /// + /// A collection of tuples describing the secondary key and data of each draft. + /// + Task> GetDrafts(string key) where T : class; + + /// + /// Updates a draft. + /// + /// The type to serialize. + /// The key. + /// The secondary key. + /// The draft data. + Task UpdateDraft(string key, string secondaryKey, T data) where T : class; + + /// + /// Removes a draft from the store. + /// + /// The key. + /// The secondary key. + Task DeleteDraft(string key, string secondaryKey); + } +} \ No newline at end of file diff --git a/src/GitHub.InlineReviews/Commands/InlineCommentNavigationCommand.cs b/src/GitHub.InlineReviews/Commands/InlineCommentNavigationCommand.cs index d86aeb9114..fe810fc3df 100644 --- a/src/GitHub.InlineReviews/Commands/InlineCommentNavigationCommand.cs +++ b/src/GitHub.InlineReviews/Commands/InlineCommentNavigationCommand.cs @@ -6,6 +6,7 @@ using GitHub.InlineReviews.Services; using GitHub.InlineReviews.Tags; using GitHub.Logging; +using GitHub.Models; using GitHub.Services; using GitHub.Services.Vssdk.Commands; using Microsoft.VisualStudio; @@ -236,7 +237,8 @@ protected void ShowPeekComments( } } - var point = peekService.Show(textView, tag); + var side = tag.DiffChangeType == DiffChangeType.Delete ? DiffSide.Left : DiffSide.Right; + var point = peekService.Show(textView, side, tag.LineNumber); if (parameter?.MoveCursor != false) { diff --git a/src/GitHub.InlineReviews/GitHub.InlineReviews.csproj b/src/GitHub.InlineReviews/GitHub.InlineReviews.csproj index 7404ed2695..bc14031ba6 100644 --- a/src/GitHub.InlineReviews/GitHub.InlineReviews.csproj +++ b/src/GitHub.InlineReviews/GitHub.InlineReviews.csproj @@ -81,6 +81,7 @@ + @@ -88,7 +89,6 @@ - @@ -97,9 +97,7 @@ - - diff --git a/src/GitHub.InlineReviews/Peek/InlineCommentPeekableItem.cs b/src/GitHub.InlineReviews/Peek/InlineCommentPeekableItem.cs index 2f07f939f2..8adb814546 100644 --- a/src/GitHub.InlineReviews/Peek/InlineCommentPeekableItem.cs +++ b/src/GitHub.InlineReviews/Peek/InlineCommentPeekableItem.cs @@ -2,10 +2,12 @@ using System.Collections.Generic; using Microsoft.VisualStudio.Language.Intellisense; using GitHub.InlineReviews.ViewModels; +using GitHub.ViewModels; +using System.Reactive; namespace GitHub.InlineReviews.Peek { - class InlineCommentPeekableItem : IPeekableItem + class InlineCommentPeekableItem : IPeekableItem, IClosable { public InlineCommentPeekableItem(InlineCommentPeekViewModel viewModel) { @@ -17,6 +19,8 @@ public InlineCommentPeekableItem(InlineCommentPeekViewModel viewModel) public IEnumerable Relationships => new[] { InlineCommentPeekRelationship.Instance }; + public IObservable Closed => ViewModel.Close; + public IPeekResultSource GetOrCreateResultSource(string relationshipName) { return new InlineCommentPeekableResultSource(ViewModel); diff --git a/src/GitHub.InlineReviews/Tags/InlineCommentGlyphFactory.cs b/src/GitHub.InlineReviews/Tags/InlineCommentGlyphFactory.cs index 38b62be7d3..6d5c4ac613 100644 --- a/src/GitHub.InlineReviews/Tags/InlineCommentGlyphFactory.cs +++ b/src/GitHub.InlineReviews/Tags/InlineCommentGlyphFactory.cs @@ -7,6 +7,8 @@ using Microsoft.VisualStudio.Text.Editor; using Microsoft.VisualStudio.Text.Formatting; using GitHub.InlineReviews.Services; +using GitHub.Models; +using GitHub.Services; namespace GitHub.InlineReviews.Tags { @@ -72,12 +74,14 @@ bool OpenThreadView(InlineCommentTag tag) if (addTag != null) { - peekService.Show(textView, addTag); + var side = addTag.DiffChangeType == DiffChangeType.Delete ? DiffSide.Left : DiffSide.Right; + peekService.Show(textView, side, addTag.LineNumber); return true; } else if (showTag != null) { - peekService.Show(textView, showTag); + var side = showTag.DiffChangeType == DiffChangeType.Delete ? DiffSide.Left : DiffSide.Right; + peekService.Show(textView, side, showTag.LineNumber); return true; } diff --git a/src/GitHub.InlineReviews/ViewModels/InlineCommentPeekViewModel.cs b/src/GitHub.InlineReviews/ViewModels/InlineCommentPeekViewModel.cs index 80cc3d6192..3f2dab8aa3 100644 --- a/src/GitHub.InlineReviews/ViewModels/InlineCommentPeekViewModel.cs +++ b/src/GitHub.InlineReviews/ViewModels/InlineCommentPeekViewModel.cs @@ -165,8 +165,6 @@ async void LinesChanged(IReadOnlyList> lines) async Task UpdateThread() { - var placeholderBody = GetPlaceholderBodyToPreserve(); - Thread = null; threadSubscription?.Dispose(); @@ -179,28 +177,18 @@ async Task UpdateThread() var thread = file.InlineCommentThreads?.FirstOrDefault(x => x.LineNumber == lineNumber && ((leftBuffer && x.DiffLineType == DiffChangeType.Delete) || (!leftBuffer && x.DiffLineType != DiffChangeType.Delete))); - - Thread = factory.CreateViewModel(); + var vm = factory.CreateViewModel(); if (thread?.Comments.Count > 0) { - await Thread.InitializeAsync(session, file, thread.Comments[0].Review, thread, true); + await vm.InitializeAsync(session, file, thread.Comments[0].Review, thread, true); } else { - await Thread.InitializeNewAsync(session, file, lineNumber, side, true); + await vm.InitializeNewAsync(session, file, lineNumber, side, true); } - if (!string.IsNullOrWhiteSpace(placeholderBody)) - { - var placeholder = Thread.Comments.LastOrDefault(); - - if (placeholder?.EditState == CommentEditState.Placeholder) - { - await placeholder.BeginEdit.Execute(); - placeholder.Body = placeholderBody; - } - } + Thread = vm; } async Task SessionChanged(IPullRequestSession pullRequestSession) @@ -219,17 +207,5 @@ async Task SessionChanged(IPullRequestSession pullRequestSession) await UpdateThread(); } } - - string GetPlaceholderBodyToPreserve() - { - var lastComment = Thread?.Comments.LastOrDefault(); - - if (lastComment?.EditState == CommentEditState.Editing) - { - if (!lastComment.IsSubmitting) return lastComment.Body; - } - - return null; - } } } diff --git a/src/GitHub.VisualStudio.UI/Views/CommentView.xaml b/src/GitHub.VisualStudio.UI/Views/CommentView.xaml index 9bf463a44a..fc9003c4ee 100644 --- a/src/GitHub.VisualStudio.UI/Views/CommentView.xaml +++ b/src/GitHub.VisualStudio.UI/Views/CommentView.xaml @@ -154,6 +154,7 @@ TextWrapping="Wrap" VerticalAlignment="Center" GotFocus="ReplyPlaceholder_GotFocus" + Loaded="body_Loaded" SpellCheck.IsEnabled="True">