cleanup command will now also clear greetsettings and autpublish channels

Cleaned up some comments, changed grpc api
This commit is contained in:
Toastie 2024-10-11 19:58:25 +13:00
parent 2548947c58
commit 4f01c30396
Signed by: toastie_t0ast
GPG key ID: 27F3B6855AFD40A4
13 changed files with 212 additions and 122 deletions

View file

@ -12,13 +12,14 @@ namespace EllieBot.Generators
{
public readonly record struct MethodPermData
{
public readonly string Name;
public readonly string Value;
public readonly ImmutableArray<(string Name, string Value)> MethodPerms;
public readonly ImmutableArray<string> NoAuthRequired;
public MethodPermData(string name, string value)
public MethodPermData(ImmutableArray<(string Name, string Value)> methodPerms,
ImmutableArray<string> noAuthRequired)
{
Name = name;
Value = value;
MethodPerms = methodPerms;
NoAuthRequired = noAuthRequired;
}
}
@ -26,7 +27,7 @@ namespace EllieBot.Generators
[Generator]
public class GrpcApiPermGenerator : IIncrementalGenerator
{
public const string Attribute =
public const string GRPC_API_PERM_ATTRIBUTE =
"""
namespace EllieBot.GrpcApi;
@ -38,12 +39,25 @@ namespace EllieBot.Generators
}
""";
public const string GRPC_NO_AUTH_REQUIRED_ATTRIBUTE =
"""
namespace EllieBot.GrpcApi;
[System.AttributeUsage(System.AttributeTargets.Method)]
public class GrpcNoAuthRequiredAttribute : System.Attribute
{
}
""";
public void Initialize(IncrementalGeneratorInitializationContext context)
{
context.RegisterPostInitializationOutput(ctx => ctx.AddSource("GrpcApiPermAttribute.cs",
SourceText.From(Attribute, Encoding.UTF8)));
SourceText.From(GRPC_API_PERM_ATTRIBUTE, Encoding.UTF8)));
var enumsToGenerate = context.SyntaxProvider
context.RegisterPostInitializationOutput(ctx => ctx.AddSource("GrpcNoAuthRequiredAttribute.cs",
SourceText.From(GRPC_NO_AUTH_REQUIRED_ATTRIBUTE, Encoding.UTF8)));
var perms = context.SyntaxProvider
.ForAttributeWithMetadataName(
"EllieBot.GrpcApi.GrpcApiPermAttribute",
predicate: static (s, _) => s is MethodDeclarationSyntax,
@ -52,11 +66,24 @@ namespace EllieBot.Generators
.Select(static (x, _) => x!.Value)
.Collect();
context.RegisterSourceOutput(enumsToGenerate,
var all = context.SyntaxProvider
.ForAttributeWithMetadataName(
"EllieBot.GrpcApi.GrpcNoAuthRequiredAttribute",
predicate: static (s, _) => s is MethodDeclarationSyntax,
transform: static (ctx, _) => GetNoAuthMethodName(ctx.SemanticModel, ctx.TargetNode))
.Collect()
.Combine(perms)
.Select((x, _) => new MethodPermData(x.Right, x.Left));
context.RegisterSourceOutput(all,
static (spc, source) => Execute(source, spc));
}
private static MethodPermData? GetMethodSemanticTargets(SemanticModel model, SyntaxNode node)
private static string GetNoAuthMethodName(SemanticModel model, SyntaxNode node)
=> ((MethodDeclarationSyntax)node).Identifier.Text;
private static (string Name, string Value)? GetMethodSemanticTargets(SemanticModel model, SyntaxNode node)
{
var method = (MethodDeclarationSyntax)node;
@ -64,20 +91,14 @@ namespace EllieBot.Generators
var attr = method.AttributeLists
.SelectMany(x => x.Attributes)
.FirstOrDefault();
// .FirstOrDefault(x => x.Name.ToString() == "GrpcApiPermAttribute");
if (attr is null)
return null;
// if (model.GetSymbolInfo(attr).Symbol is not IMethodSymbol attrSymbol)
// return null;
return new MethodPermData(name, attr.ArgumentList.Arguments[0].ToString() ?? "__missing_perm__");
// return new MethodPermData(name, attrSymbol.Parameters[0].ContainingType.ToDisplayString() + "." + attrSymbol.Parameters[0].Name);
return (name, attr.ArgumentList?.Arguments[0].ToString() ?? "__missing_perm__");
}
private static void Execute(ImmutableArray<MethodPermData> fields, SourceProductionContext ctx)
private static void Execute(MethodPermData data, SourceProductionContext ctx)
{
using (var stringWriter = new StringWriter())
using (var sw = new IndentedTextWriter(stringWriter))
@ -92,11 +113,12 @@ namespace EllieBot.Generators
sw.Indent++;
sw.WriteLine("public static FrozenDictionary<string, GuildPerm> perms = new Dictionary<string, GuildPerm>()");
sw.WriteLine(
"private static FrozenDictionary<string, GuildPerm> _perms = new Dictionary<string, GuildPerm>()");
sw.WriteLine("{");
sw.Indent++;
foreach (var field in fields)
foreach (var field in data.MethodPerms)
{
sw.WriteLine("{{ \"{0}\", {1} }},", field.Name, field.Value);
}
@ -104,6 +126,21 @@ namespace EllieBot.Generators
sw.Indent--;
sw.WriteLine("}.ToFrozenDictionary();");
sw.WriteLine();
sw.WriteLine("private static FrozenSet<string> _noAuthRequired = new HashSet<string>()");
sw.WriteLine("{");
sw.Indent++;
foreach (var noauth in data.NoAuthRequired)
{
sw.WriteLine("{{ \"{0}\" }},", noauth);
}
sw.WriteLine("");
sw.Indent--;
sw.WriteLine("}.ToFrozenSet();");
sw.Indent--;
sw.WriteLine("}");

View file

@ -11,7 +11,7 @@ service GrpcGreet {
}
message GrpcGreetSettings {
optional uint64 channelId = 1;
string channelId = 1;
string message = 2;
bool isEnabled = 3;
GrpcGreetType type = 4;

View file

@ -62,7 +62,6 @@ public abstract class EllieContext : DbContext
public DbSet<ArchivedTodoListModel> TodosArchive { get; set; }
public DbSet<HoneypotChannel> HoneyPotChannels { get; set; }
// todo add guild colors
// public DbSet<GuildColors> GuildColors { get; set; }

View file

@ -37,27 +37,37 @@ public class AutoPublishService : IExecNoCommand, IReadyExecutor, IEService
});
}
// todo GUILDS
public async Task OnReadyAsync()
{
var creds = _creds.GetCreds();
await using var ctx = _db.GetDbContext();
var items = await ctx.GetTable<AutoPublishChannel>()
.Where(x => Linq2DbExpressions.GuildOnShard(x.GuildId, creds.TotalShards, _client.ShardId))
.ToListAsyncLinqToDB();
.Where(x => Linq2DbExpressions.GuildOnShard(x.GuildId, creds.TotalShards, _client.ShardId))
.ToListAsyncLinqToDB();
_enabled = items
.ToDictionary(x => x.GuildId, x => x.ChannelId)
.ToConcurrent();
.ToDictionary(x => x.GuildId, x => x.ChannelId)
.ToConcurrent();
_client.LeftGuild += ClientOnLeftGuild;
}
public async Task ClientOnLeftGuild(SocketGuild guild)
{
await using var ctx = _db.GetDbContext();
_enabled.TryRemove(guild.Id, out _);
await ctx.GetTable<AutoPublishChannel>()
.Where(x => x.GuildId == guild.Id)
.DeleteAsync();
}
public async Task<bool> ToggleAutoPublish(ulong guildId, ulong channelId)
{
await using var ctx = _db.GetDbContext();
var deleted = await ctx.GetTable<AutoPublishChannel>()
.DeleteAsync(x => x.GuildId == guildId && x.ChannelId == channelId);
.DeleteAsync(x => x.GuildId == guildId && x.ChannelId == channelId);
if (deleted != 0)
{
@ -66,21 +76,21 @@ public class AutoPublishService : IExecNoCommand, IReadyExecutor, IEService
}
await ctx.GetTable<AutoPublishChannel>()
.InsertOrUpdateAsync(() => new()
{
GuildId = guildId,
ChannelId = channelId,
DateAdded = DateTime.UtcNow,
},
old => new()
{
ChannelId = channelId,
DateAdded = DateTime.UtcNow,
},
() => new()
{
GuildId = guildId
});
.InsertOrUpdateAsync(() => new()
{
GuildId = guildId,
ChannelId = channelId,
DateAdded = DateTime.UtcNow,
},
old => new()
{
ChannelId = channelId,
DateAdded = DateTime.UtcNow,
},
() => new()
{
GuildId = guildId
});
_enabled[guildId] = channelId;

View file

@ -207,6 +207,18 @@ public sealed class CleanupService : ICleanupService, IReadyExecutor, IEService
.Contains(x.GuildId))
.DeleteAsync();
// delete autopublish channels
await ctx.GetTable<AutoPublishChannel>()
.Where(x => !tempTable.Select(x => x.GuildId)
.Contains(x.GuildId))
.DeleteAsync();
// delete greet settings
await ctx.GetTable<GreetSettings>()
.Where(x => !tempTable.Select(x => x.GuildId)
.Contains(x.GuildId))
.DeleteAsync();
return new()
{
GuildCount = guildIds.Keys.Count,

View file

@ -100,10 +100,6 @@ public sealed class AiAssistantService
using var client = _httpFactory.CreateClient();
// todo customize according to the bot's config
// - CurrencyName
// -
using var response = await client.SendAsync(request);
if (response.StatusCode == HttpStatusCode.TooManyRequests)

View file

@ -14,7 +14,6 @@ public class ExprsSvc : GrpcExprs.GrpcExprsBase, IEService
_svc = svc;
}
[GrpcApiPerm(GuildPerm.Administrator)]
public override async Task<AddExprReply> AddExpr(AddExprRequest request, ServerCallContext context)
{
EllieExpression expr;
@ -45,7 +44,6 @@ public class ExprsSvc : GrpcExprs.GrpcExprsBase, IEService
};
}
[GrpcApiPerm(GuildPerm.Administrator)]
public override async Task<GetExprsReply> GetExprs(GetExprsRequest request, ServerCallContext context)
{
var (exprs, totalCount) = await _svc.FindExpressionsAsync(request.GuildId, request.Query, request.Page);
@ -66,7 +64,6 @@ public class ExprsSvc : GrpcExprs.GrpcExprsBase, IEService
return reply;
}
[GrpcApiPerm(GuildPerm.Administrator)]
public override async Task<Empty> DeleteExpr(DeleteExprRequest request, ServerCallContext context)
{
await _svc.DeleteAsync(request.GuildId, new kwum(request.Id));

View file

@ -23,12 +23,11 @@ public sealed class GreetByeSvc : GrpcGreet.GrpcGreetBase, IEService
{
Message = conf.MessageText,
Type = (GrpcGreetType)conf.GreetType,
ChannelId = conf.ChannelId ?? 0,
ChannelId = conf.ChannelId?.ToString() ?? string.Empty,
IsEnabled = conf.IsEnabled,
};
}
[GrpcApiPerm(GuildPerm.Administrator)]
public override async Task<GrpcGreetSettings> GetGreetSettings(GetGreetRequest request, ServerCallContext context)
{
var guildId = request.GuildId;
@ -38,7 +37,6 @@ public sealed class GreetByeSvc : GrpcGreet.GrpcGreetBase, IEService
return ToConf(conf);
}
[GrpcApiPerm(GuildPerm.Administrator)]
public override async Task<UpdateGreetReply> UpdateGreet(UpdateGreetRequest request, ServerCallContext context)
{
var gid = request.GuildId;
@ -48,7 +46,7 @@ public sealed class GreetByeSvc : GrpcGreet.GrpcGreetBase, IEService
var type = GetGreetType(s.Type);
await _gs.SetMessage(gid, GetGreetType(s.Type), msg);
await _gs.SetGreet(gid, s.ChannelId, type, s.IsEnabled);
await _gs.SetGreet(gid, ulong.Parse(s.ChannelId), type, s.IsEnabled);
var settings = await _gs.GetGreetSettingsAsync(gid, type);
if (settings is null)
@ -60,7 +58,6 @@ public sealed class GreetByeSvc : GrpcGreet.GrpcGreetBase, IEService
};
}
[GrpcApiPerm(GuildPerm.Administrator)]
public override Task<TestGreetReply> TestGreet(TestGreetRequest request, ServerCallContext context)
=> TestGreet(request.GuildId, request.ChannelId, request.UserId, request.Type);

View file

@ -19,6 +19,7 @@ public sealed class OtherSvc : GrpcOther.GrpcOtherBase, IEService
private readonly WaifuService _waifus;
private readonly ICoordinator _coord;
private readonly IStatsService _stats;
private readonly IBotCache _cache;
public OtherSvc(
DiscordSocketClient client,
@ -26,7 +27,8 @@ public sealed class OtherSvc : GrpcOther.GrpcOtherBase, IEService
ICurrencyService cur,
WaifuService waifus,
ICoordinator coord,
IStatsService stats)
IStatsService stats,
IBotCache cache)
{
_client = client;
_xp = xp;
@ -34,35 +36,9 @@ public sealed class OtherSvc : GrpcOther.GrpcOtherBase, IEService
_waifus = waifus;
_coord = coord;
_stats = stats;
_cache = cache;
}
public override async Task<GetGuildsReply> GetGuilds(Empty request, ServerCallContext context)
{
var guilds = await _client.GetGuildsAsync(CacheMode.CacheOnly);
var reply = new GetGuildsReply();
var userId = context.GetUserId();
var toReturn = new List<IGuild>();
foreach (var g in guilds)
{
var user = await g.GetUserAsync(userId, CacheMode.AllowDownload);
if (user.GuildPermissions.Has(GuildPermission.Administrator))
toReturn.Add(g);
}
reply.Guilds.AddRange(toReturn
.Select(x => new GuildReply()
{
Id = x.Id,
Name = x.Name,
IconUrl = x.IconUrl
}));
return reply;
}
[GrpcApiPerm(GuildPerm.Administrator)]
public override async Task<GetTextChannelsReply> GetTextChannels(
GetTextChannelsRequest request,
ServerCallContext context)
@ -81,6 +57,35 @@ public sealed class OtherSvc : GrpcOther.GrpcOtherBase, IEService
return reply;
}
[GrpcNoAuthRequired]
public override async Task<GetGuildsReply> GetGuilds(Empty request, ServerCallContext context)
{
var guilds = await _client.GetGuildsAsync(CacheMode.CacheOnly);
var reply = new GetGuildsReply();
var userId = context.GetUserId();
var toReturn = new List<IGuild>();
foreach (var g in guilds)
{
var user = await g.GetUserAsync(userId);
if (user.GuildPermissions.Has(GuildPermission.Administrator))
toReturn.Add(g);
}
reply.Guilds.AddRange(toReturn
.Select(x => new GuildReply()
{
Id = x.Id,
Name = x.Name,
IconUrl = x.IconUrl
}));
return reply;
}
[GrpcNoAuthRequired]
public override async Task<CurrencyLbReply> GetCurrencyLb(GetLbRequest request, ServerCallContext context)
{
var users = await _cur.GetTopRichest(_client.CurrentUser.Id, request.Page, request.PerPage);
@ -103,6 +108,7 @@ public sealed class OtherSvc : GrpcOther.GrpcOtherBase, IEService
return reply;
}
[GrpcNoAuthRequired]
public override async Task<XpLbReply> GetXpLb(GetLbRequest request, ServerCallContext context)
{
var users = await _xp.GetGlobalUserXps(request.Page);
@ -127,6 +133,7 @@ public sealed class OtherSvc : GrpcOther.GrpcOtherBase, IEService
return reply;
}
[GrpcNoAuthRequired]
public override async Task<WaifuLbReply> GetWaifuLb(GetLbRequest request, ServerCallContext context)
{
var waifus = await _waifus.GetTopWaifusAtPage(request.Page, request.PerPage);
@ -142,11 +149,15 @@ public sealed class OtherSvc : GrpcOther.GrpcOtherBase, IEService
return reply;
}
public override Task<GetShardStatusesReply> GetShardStatuses(Empty request, ServerCallContext context)
[GrpcNoAuthRequired]
public override async Task<GetShardStatusesReply> GetShardStatuses(Empty request, ServerCallContext context)
{
var reply = new GetShardStatusesReply();
// todo cache
await _cache.GetOrAddAsync<List<ShardStatus>>("coord:statuses",
() => Task.FromResult(_coord.GetAllShardStatuses().ToList())!,
TimeSpan.FromMinutes(1));
var shards = _coord.GetAllShardStatuses();
reply.Shards.AddRange(shards.Select(x => new ShardStatusReply()
@ -157,10 +168,10 @@ public sealed class OtherSvc : GrpcOther.GrpcOtherBase, IEService
LastUpdate = Timestamp.FromDateTime(x.LastUpdate),
}));
return Task.FromResult(reply);
return reply;
}
[GrpcApiPerm(GuildPerm.Administrator)]
public override async Task<GetServerInfoReply> GetServerInfo(ServerInfoRequest request, ServerCallContext context)
{
var info = await _stats.GetGuildInfoAsync(request.GuildId);

View file

@ -5,12 +5,13 @@ namespace EllieBot.GrpcApi;
public sealed partial class GrpcApiPermsInterceptor : Interceptor
{
private const GuildPerm DEFAULT_PERMISSION = GuildPermission.Administrator;
private readonly DiscordSocketClient _client;
public GrpcApiPermsInterceptor(DiscordSocketClient client)
{
_client = client;
Log.Information("interceptor created");
}
public override async Task<TResponse> UnaryServerHandler<TRequest, TResponse>(
@ -20,42 +21,45 @@ public sealed partial class GrpcApiPermsInterceptor : Interceptor
{
try
{
Log.Information("Starting receiving call. Type/Method: {Type} / {Method}",
MethodType.Unary,
context.Method);
var method = context.Method[(context.Method.LastIndexOf('/') + 1)..];
// get metadata
var metadata = context
.RequestHeaders
.ToDictionary(x => x.Key, x => x.Value);
Log.Information("grpc | g: {GuildId} | u: {UserID} | cmd: {Method}",
metadata.TryGetValue("guildid", out var gidString) ? gidString : "none",
metadata.TryGetValue("userid", out var uidString) ? uidString : "none",
method);
// there always has to be a user who makes the call
if (!metadata.ContainsKey("userid"))
throw new RpcException(new Status(StatusCode.Unauthenticated, "userid has to be specified"));
throw new RpcException(new(StatusCode.Unauthenticated, "userid has to be specified."));
var method = context.Method[(context.Method.LastIndexOf('/') + 1)..];
// get the method name without the service name
if (perms.TryGetValue(method, out var perm))
// if the method is explicitly marked as not requiring auth
if (_noAuthRequired.Contains(method))
return await continuation(request, context);
// otherwise the method requires auth, and if it requires auth then the guildid has to be specified
if (!metadata.ContainsKey("guildid"))
throw new RpcException(new(StatusCode.Unauthenticated, "guildid has to be specified."));
var userId = ulong.Parse(metadata["userid"]);
var guildId = ulong.Parse(gidString);
// check if the user has the required permission
if (_perms.TryGetValue(method, out var perm))
{
Log.Information("Required permission for {Method} is {Perm}",
method,
perm);
var userId = ulong.Parse(metadata["userid"]);
var guildId = ulong.Parse(metadata["guildid"]);
IGuild guild = _client.GetGuild(guildId);
var user = guild is null ? null : await guild.GetUserAsync(userId);
if (user is null)
throw new RpcException(new Status(StatusCode.NotFound, "User not found"));
if (!user.GuildPermissions.Has(perm))
throw new RpcException(new Status(StatusCode.PermissionDenied,
$"You need {perm} permission to use this method"));
await EnsureUserHasPermission(guildId, userId, perm);
}
else
{
Log.Information("No permission required for {Method}", method);
// if not then use the default, which is Administrator permission
await EnsureUserHasPermission(guildId, userId, DEFAULT_PERMISSION);
}
return await continuation(request, context);
@ -66,4 +70,17 @@ public sealed partial class GrpcApiPermsInterceptor : Interceptor
throw;
}
}
private async Task EnsureUserHasPermission(ulong guildId, ulong userId, GuildPerm perm)
{
IGuild guild = _client.GetGuild(guildId);
var user = guild is null ? null : await guild.GetUserAsync(userId);
if (user is null)
throw new RpcException(new Status(StatusCode.NotFound, "User not found"));
if (!user.GuildPermissions.Has(perm))
throw new RpcException(new Status(StatusCode.PermissionDenied,
$"You need {perm} permission to use this method"));
}
}

View file

@ -41,6 +41,19 @@ public class GrpcApiService : IEService, IReadyExecutor
var interceptor = new GrpcApiPermsInterceptor(_client);
var serverCreds = ServerCredentials.Insecure;
if (creds.GrpcApi is
{
CertPrivateKey: not null and not "",
CertChain: not null and not ""
} cert)
{
serverCreds = new SslServerCredentials(
new[] { new KeyCertificatePair(cert.CertChain, cert.CertPrivateKey) });
}
_app = new Server()
{
Services =
@ -51,7 +64,7 @@ public class GrpcApiService : IEService, IReadyExecutor
},
Ports =
{
new(host, port, ServerCredentials.Insecure),
new(host, port, serverCreds),
}
};
@ -59,8 +72,9 @@ public class GrpcApiService : IEService, IReadyExecutor
Log.Information("Grpc Api Server started on port {Host}:{Port}", host, port);
}
catch
catch (Exception ex)
{
Log.Error(ex, "Error starting Grpc Api Server");
_app?.ShutdownAsync().GetAwaiter().GetResult();
}

View file

@ -6,7 +6,7 @@ namespace EllieBot.Common;
public sealed class Creds : IBotCreds
{
[Comment("""DO NOT CHANGE""")]
public int Version { get; set; } = 12;
public int Version { get; set; } = 13;
[Comment("""Bot token. Do not share with anyone ever -> https://discordapp.com/developers/applications/""")]
public string Token { get; set; }
@ -293,8 +293,8 @@ public sealed class Creds : IBotCreds
public sealed record GrpcApiConfig
{
public bool Enabled { get; set; } = false;
public string CertPath { get; set; } = string.Empty;
public string CertPassword { get; set; } = string.Empty;
public string CertChain { get; set; } = string.Empty;
public string CertPrivateKey { get; set; } = string.Empty;
public string Host { get; set; } = "localhost";
public int Port { get; set; } = 43120;
}

View file

@ -140,9 +140,9 @@ public sealed class BotCredsProvider : IBotCredsProvider
creds.BotCache = BotCacheImplemenation.Memory;
}
if (creds.Version < 12)
if (creds.Version < 13)
{
creds.Version = 12;
creds.Version = 13;
File.WriteAllText(CREDS_FILE_NAME, Yaml.Serializer.Serialize(creds));
}
}