using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using Marr.Data;
using Marr.Data.QGen;
using NzbDrone.Common.Messaging;
using NzbDrone.Core.Datastore.Events;
using NzbDrone.Common;

namespace NzbDrone.Core.Datastore
{
    /// <summary>
    /// Basic create/read/update/delete operations over the rows mapped to <typeparamref name="TModel"/>.
    /// </summary>
    /// <typeparam name="TModel">Entity type; rows are identified by the integer <c>Id</c> of <see cref="ModelBase"/>.</typeparam>
    public interface IBasicRepository<TModel> where TModel : ModelBase, new()
    {
        /// <summary>Returns every row.</summary>
        IEnumerable<TModel> All();

        /// <summary>Returns the number of rows.</summary>
        int Count();

        /// <summary>Returns the single row whose <c>Id</c> equals <paramref name="id"/>.</summary>
        TModel Get(int id);

        /// <summary>Returns the rows whose ids are in <paramref name="ids"/>; throws when any id has no matching row.</summary>
        IEnumerable<TModel> Get(IEnumerable<int> ids);

        /// <summary>Returns the only row, or the default value when there are no rows.</summary>
        TModel SingleOrDefault();

        /// <summary>Inserts a model that has not been assigned an id yet (<c>Id == 0</c>).</summary>
        TModel Insert(TModel model);

        /// <summary>Updates the row whose <c>Id</c> matches the model's id.</summary>
        TModel Update(TModel model);

        /// <summary>Inserts the model when its <c>Id</c> is 0, otherwise updates it.</summary>
        TModel Upsert(TModel model);

        /// <summary>Deletes the row with the given id.</summary>
        void Delete(int id);

        /// <summary>Deletes the row whose <c>Id</c> matches the model's id.</summary>
        void Delete(TModel model);

        /// <summary>Inserts each model in turn.</summary>
        void InsertMany(IList<TModel> model);

        /// <summary>Updates each model in turn.</summary>
        void UpdateMany(IList<TModel> model);

        /// <summary>Deletes each model in turn.</summary>
        void DeleteMany(List<TModel> model);

        /// <summary>Deletes all rows.</summary>
        void Purge();

        /// <summary>Returns <c>true</c> when at least one row exists.</summary>
        bool HasItems();

        /// <summary>Deletes the rows with the given ids, one at a time.</summary>
        void DeleteMany(IEnumerable<int> ids);

        /// <summary>Writes only the listed <paramref name="properties"/> of <paramref name="model"/> to its row.</summary>
        void SetFields(TModel model, params Expression<Func<TModel, object>>[] properties);

        /// <summary>Returns the only row; throws unless exactly one row exists.</summary>
        TModel Single();

        /// <summary>Fills <paramref name="pagingSpec"/> with one page of records plus the total row count.</summary>
        PagingSpec<TModel> GetPaged(PagingSpec<TModel> pagingSpec);
    }

    /// <summary>
    /// Default <see cref="IBasicRepository{TModel}"/> implementation that runs its queries
    /// through the <c>DataMapper</c> of the supplied <see cref="IDatabase"/>.
    /// </summary>
    /// <typeparam name="TModel">Entity type stored by this repository.</typeparam>
    public class BasicRepository<TModel> : IBasicRepository<TModel> where TModel : ModelBase, new()
    {
        private readonly IDatabase _database;
        private readonly IMessageAggregator _messageAggregator;

        // Read on every access rather than cached, so it always reflects _database.DataMapper.
        private IDataMapper DataMapper
        {
            get { return _database.DataMapper; }
        }

        /// <param name="database">Database whose data mapper executes all queries.</param>
        /// <param name="messageAggregator">Receives model events when <see cref="PublishModelEvents"/> is enabled.</param>
        public BasicRepository(IDatabase database, IMessageAggregator messageAggregator)
        {
            _database = database;
            _messageAggregator = messageAggregator;
        }

        /// <summary>Query builder over <typeparamref name="TModel"/>; obtained from the data mapper on each access.</summary>
        protected QueryBuilder<TModel> Query
        {
            get { return DataMapper.Query<TModel>(); }
        }

        /// <summary>Deletes every row matching <paramref name="filter"/>.</summary>
        protected void Delete(Expression<Func<TModel, bool>> filter)
        {
            DataMapper.Delete<TModel>(filter);
        }

        public IEnumerable<TModel> All()
        {
            return DataMapper.Query<TModel>().ToList();
        }

        public int Count()
        {
            return DataMapper.Query<TModel>().GetRowCount();
        }

        public TModel Get(int id)
        {
            // Filter with Where before Single so the id predicate is part of the query
            // instead of being evaluated against every row that the builder yields.
            return DataMapper.Query<TModel>().Where(c => c.Id == id).Single();
        }

        /// <exception cref="ApplicationException">
        /// The number of rows returned differs from the number of distinct ids requested.
        /// </exception>
        public IEnumerable<TModel> Get(IEnumerable<int> ids)
        {
            // Materialise once: the ids feed both the SQL text and the row-count check,
            // and a lazily evaluated caller sequence must not be enumerated repeatedly.
            var idList = ids.ToList();

            // Nothing to look up; also avoids sending an empty "IN ()" list, which is
            // not portable SQL.
            if (idList.Count == 0)
            {
                return new List<TModel>();
            }

            // IN (...) matches each row at most once, so duplicated ids must not raise
            // the expected row count.
            var expectedCount = idList.Distinct().Count();

            // The ids are ints, so they are safe to embed directly in the filter text.
            var query = String.Format("Id IN ({0})", String.Join(",", idList));
            var result = Query.Where(query).ToList();

            if (result.Count != expectedCount)
            {
                throw new ApplicationException("Expected query to return {0} rows but returned {1}".Inject(expectedCount, result.Count));
            }

            return result;
        }

        /// <remarks>Loads every row, then applies LINQ <c>SingleOrDefault</c> in memory.</remarks>
        public TModel SingleOrDefault()
        {
            return All().SingleOrDefault();
        }

        /// <remarks>Loads every row, then applies LINQ <c>Single</c> in memory.</remarks>
        public TModel Single()
        {
            return All().Single();
        }

        /// <exception cref="InvalidOperationException">The model already has a non-zero id.</exception>
        public TModel Insert(TModel model)
        {
            if (model.Id != 0)
            {
                throw new InvalidOperationException("Can't insert model with existing ID " + model.Id);
            }

            DataMapper.Insert(model);
            PublishModelEvent(model, RepositoryAction.Created);

            return model;
        }

        /// <exception cref="InvalidOperationException">The model's id is 0.</exception>
        public TModel Update(TModel model)
        {
            if (model.Id == 0)
            {
                throw new InvalidOperationException("Can't update model with ID 0");
            }

            DataMapper.Update(model, c => c.Id == model.Id);
            return model;
        }

        public void Delete(TModel model)
        {
            DataMapper.Delete<TModel>(c => c.Id == model.Id);
        }

        /// <remarks>Calls <see cref="Insert"/> once per model; there is no enclosing transaction.</remarks>
        public void InsertMany(IList<TModel> models)
        {
            foreach (var model in models)
            {
                Insert(model);
            }
        }

        /// <remarks>Calls <see cref="Update"/> once per model; there is no enclosing transaction.</remarks>
        public void UpdateMany(IList<TModel> models)
        {
            foreach (var model in models)
            {
                Update(model);
            }
        }

        public void DeleteMany(List<TModel> models)
        {
            models.ForEach(Delete);
        }

        public TModel Upsert(TModel model)
        {
            if (model.Id == 0)
            {
                Insert(model);
                return model;
            }

            Update(model);
            return model;
        }

        public void Delete(int id)
        {
            DataMapper.Delete<TModel>(c => c.Id == id);
        }

        public void DeleteMany(IEnumerable<int> ids)
        {
            ids.ToList().ForEach(Delete);
        }

        public void Purge()
        {
            // The predicate matches every row, on the assumption that ids are never
            // below zero.
            DataMapper.Delete<TModel>(c => c.Id > -1);
        }

        public bool HasItems()
        {
            return Count() > 0;
        }

        /// <exception cref="InvalidOperationException">The model's id is 0.</exception>
        public void SetFields(TModel model, params Expression<Func<TModel, object>>[] properties)
        {
            if (model.Id == 0)
            {
                throw new InvalidOperationException("Attempted to updated model without ID");
            }

            DataMapper.Update<TModel>()
                      .Where(c => c.Id == model.Id)
                      .ColumnsIncluding(properties)
                      .Entity(model)
                      .Execute();
        }

        public virtual PagingSpec<TModel> GetPaged(PagingSpec<TModel> pagingSpec)
        {
            var pagingQuery = Query.OrderBy(pagingSpec.OrderByClause(), pagingSpec.ToSortDirection())
                                   .Skip(pagingSpec.PagingOffset())
                                   .Take(pagingSpec.PageSize);

            pagingSpec.Records = pagingQuery.ToList();

            //TODO: Use the same query for count and records
            pagingSpec.TotalRecords = Count();

            return pagingSpec;
        }

        // Within this class only Insert calls this, and nothing is published unless a
        // subclass overrides PublishModelEvents to return true.
        private void PublishModelEvent(TModel model, RepositoryAction action)
        {
            if (PublishModelEvents)
            {
                _messageAggregator.PublishEvent(new ModelEvent<TModel>(model, action));
            }
        }

        /// <summary>Extension hook for subclasses; this class never invokes it.</summary>
        protected virtual void OnModelChanged(IEnumerable<TModel> models)
        {
        }

        /// <summary>Extension hook for subclasses; this class never invokes it.</summary>
        protected virtual void OnModelDeleted(IEnumerable<TModel> models)
        {
        }

        /// <summary>When <c>true</c>, <see cref="Insert"/> publishes a <c>Created</c> model event. Defaults to <c>false</c>.</summary>
        protected virtual bool PublishModelEvents
        {
            get { return false; }
        }
    }
}