EFCore查询语句生成流程、让EFCore支持批量Update/Delete/MergeInto

引子

之前发现了一款叫 EFCore.BulkExtensions 的 nuget 包。里面提供了大量的 BulkInsertOrUpdateOrDelete 和 BatchUpdate 的拓展,可以很方便的解决批量更新和删除的问题,不用让 EFCore 一条一条的删除和更新。

其中几个比较有用的函数签名是

Task<int> BatchDeleteAsync(this IQueryable<T> queryable);
Task<int> BatchUpdateAsync(this IQueryable<T> queryable, Expression<Func<T, T>> updateExpression);

但是在升级到 ASP.NET Core 3.1 的时候,所有 Where 中的 someArray.Contains(i.Key) 全部挂掉了。而我的程序里用这一语句比较多,遂下载了其源代码并合并了当时作者几个月都没合并的一个PR。

研究代码,总结了该程序的基本运行过程:

  1. 通过反射获取各种私有变量来访问到 DbContext
  2. updateExpression 由这个包自己访问表达式树获得
  3. 让 IQueryable 执行 GetEnumerator 让 EFCore 生成对应的 Select 语句,进行字符串拼接
  4. 由 DbContext.Database.ExecuteSqlRaw 来完成语句执行

但是这过程有几个问题:

  1. 有几种句式 updateExpression 会翻译不了
  2. 由其原来实现的 updateExpression 翻译后的某些参数的 SQL 类型不对
  3. 我需要一个 INSERT INTO SELECT FROM 的句式,它不支持
  4. 我需要一个 upsert 功能,但是原来的 BulkInsertOrUpdate 不能在原表基础上操作

遂研究 IQueryable.Provider.Execute<T> 是什么执行流程。

语句生成过程

我觉得在翻代码的过程中,有这么一首歌比较符合我的心情:如果你愿意一层一层一层一层的拨开我的心,你会发现,你会讶异,你是我最压抑最深处的秘密。

  1. 调用 QueryCompiler.ExtractParameters,将其中的闭包捕捉变量参数化
  2. 检查是否已经缓存了这个查询表达式,如果没有则转入 QueryCompilationContext 处理,否则转到8
  3. QueryTranslationPreprocessor 处理,在原来的表达式树上先跳舞
  4. QueryableMethodTranslatingExpressionVisitor 将原来的表达式树翻译成一个 ShapedQueryExpression,而这一个表达式则包含了几个部分:SelectExpressionShaperExpressionResultCardinality。其中前者是可以翻译成 SQL 语句的表达式,中间的是将查询出来的元组映射到实体类型,最后一个是查询的维度(Enumerable、Single、SingleOrDefault)
  5. QueryTranslationPostprocessor 处理,其中比较重要的是将查询的字段加入 SELECT 的 Projection 列表
  6. ShapedQueryCompilingExpressionVisitorShapedQueryExpression 缓存,并转换成为 IRelationcalCommandCache,然后构造一个 QueryableEnumerable 的 NewExpression。前者包含了该查询语句需要的参数、查询语法树、查询字符串,后者是进行语句执行的类
  7. 将上述 NewExpression 和将 QueryCompilationContext 中的查询参数加到 QueryContext 中的语句合并成为一个代码块,然后 Lambda Compile
  8. 生成 DbCommandIRelationcalCommandCache 获取字符串并加入各种参数进行查询

翻译结束了,查询到这里也就可以开始了。

支持批量操作?

IRelationalCommandCache 是怎么生成字符串的呢?没错,就是 QuerySqlGenerator 啦。

那么,也就是说,我们能过拿到 Select Expression 的话,一切都好说。

上述过程中,最后的 IRelationalCommandCache 中会包含这个 SelectExpression。我们可以魔改这个啊!

DELETE 语句的生成比较简单。我们构建一个 DeleteExpression 类,将要删除的 Table、删除中的 Predicate、删除个数限制 Limit、原来的一些 Join 全部获取出来,就好了。然后在我们自己继承的 SqlServerQuerySqlGenerator 中实现这个部分。

INSERT INTO SELECT 也比较简单,只要构建一个 InsertIntoSelectExpression 类,将要插入的表 Table 和 SelectExpression 保存起来,就好了。

UPDATE SET 可能比较麻烦。但是我们可以骚操作啊!将那个 updateExpression 变成 Select 的字段,然后再读取 SelectExpression 中的 ProjectionExpression 不就好了吗~我真是个小天才。

MERGE INTO 是最烦的,因为结构过于复杂,涉及到 Target、Source、JoinPredicate、Limit、Matched、NotMatchedByTarget、NotMatchedBySource。过程中还要实现一些表的更名之类的。目前我只是实现了这些,但是想做出 Matched When 功能以后再发布到 nuget 上,这个实现实在是过于复杂,不知道有没有人帮帮我啊 TAT。

由于翻译 SqlExpression 最方便还是基于 QuerySqlGenerator 操作,所以就写一个 EnhancedQuerySqlGenerator 类来满足我们的需求,并在 DbContextOptionsBuilder 那边将这个 Factory 替换掉。

实现了这些,GitHub 地址:Microsoft.EntityFrameworkCore.Bulk,可以在 github packages 上下载目前版本的 nuget 包。

另外 src/Internal/TranslationGoThrough.cs 中有上述语句生成过程的一个缩影,和系统版本几乎一致,唯一不同的是修改了 ExtractParameters 函数。

因为原来的 Extract 过程有一个事情很诡异:在生成参数的时候,我们可以进行一些本地执行,但是如果不阻止某些本地执行程的话,可能会导致 UPDATE 语句的字段全部空。例如 updateExpression 中没有利用到原表的参数并且不捕捉闭包变量的时候,那么不会被本地执行,但是如果没有利用到原表的参数还捕捉闭包变量的时候,它就会被直接本地执行,字段空啦~(确实不懂他们这段代码逻辑怎么写的,你生成查询的时候优化这个的话,怎么不把前面一个也优化掉啊……

ASP.NET Core 修改 EndpointRouting 的链接生成行为

微软在 ASP.NET Core 2.2 时期引入了 EndpointRouting,并且在之后的 3.0 / 3.1 进行了很多的升级改造。我前段时间刚刚把网站升级到了 3.1,之前一直在使用 2.1 的兼容模式,并且内部有很多的属性路由的配置,而没有使用 DefaultRoute 那条规则。按照官方的文档将 UseMvc 那套替换成了 EndpointRouting 那一套。然后发现,很多的链接生成出了问题。

由于之前略微读过 AnchorTagHelper 的源码,大约知道问题出在了 UrlHelper 身上,将。随后发现, 在使用终结点路由后,UrlHelper 的实现类型变成了 EndpointRoutingUrlHelper。而后者则在内部调用了 LinkGenerator,其默认实现为 DefaultLinkGenerator

LinkGenerator 的实现使用到了几个比较重要的对象:EndpointDataSource 是所有终结点的集合,TemplateBinderFactory 是根据终结点的 RoutePattern 生成 TemplateBinder 的工厂,TemplateBinder 是一个保存了 RoutePattern 内部信息(例如默认RouteValues、链接模板)并提供实际链接生成的工具。由于微软在这方面的代码编写中全都 internal 了,且代码中注释较少,在接下来我将介绍他们的基本工作流程,并给出一个能够基本做到“兼容”、不修改太多代码的方法。此处的解释仅仅针对属性路由。

寻找终结点

请关注 GetEndpoints<TAddress>(TAddress address) 函数。此处,TAddress 取值有两种:string、RouteValuesAddress。

当 TAddress 为 string 时,请回顾 RouteAttribute 中的 Name 属性。没错就是这个,在 anchor 链接使用 asp-route=”something” 的时候,路由的查找是根据此处进行的。在生成终结点时,程序已经确保了所有的 RouteName 都唯一。查找直接找字典就好了。

当 TAddress 为 RouteValuesAddress 时,是使用 asp-area、asp-controller、asp-page、asp-action、asp-route-xxx 时进行的寻址方式。此 Address 由三个部分组成,其中包括 RouteName、ExplicitValues、AmbientValues。根据 RoutePattern 中的 RequiredValues 来检查 ExplicitValues 和 AmbientValues 的值。

结束后返回一个终结点列表,其中包含可能匹配的结果。我们针对每个可能的终结点,尝试匹配路由模板。

尝试匹配路由模板

生成一个 TemplateBinder,这个对象提供三个函数:GetValues、TryProcessConstraints、TryBindValues。

GetValues 会根据 RoutePattern 和两种 RouteValue 生成一个最终匹配列表,使用了 Ambient Values Invalidation Algorithm,在接下来介绍。

TryProcessConstraints 是根据上方匹配,检查是否满足路由值的限制。

TryBindValues 是将获取的 RouteValues 生成最后的终结点链接。

隐式路由值失效算法

首先介绍几个字段,在接下来会被算法使用到。

  • _slots 是一个 KVP<string, object>,其 Key 为路由值的名称,Value 为 null。其中前几个字段依次是 pattern 的参数,后几个字段是filter默认值。

  • _requiredKeys 是一个 string 数组,是以 RoutePattern 的字段中提取的。

  • _defaults 是默认路由值的字典。

  • _pattern 是 RoutePattern。

  • ambientValues 是目前页面的路由值字典。

  • values 是即将导航的页面的路由值字典,一般由asp-action等等计算得到。

最开始会将“需要填的空格”复制一份,以供本次计算使用。

  • 首先将所有需要填的空在 values 中寻找一遍,决定使用该值或者 explicit null 值。

  • 考虑针对每个 requiredKeys 是否继续复制隐式路由值。

    • 如果 ambientValues 为空引用,那么不复制。
    • 对于每一个在上一步中确定的显式值,检查和隐式路由值是否相同。检查过程会考虑值相等、any值匹配、显式null值。
  • 对于每一个路由参数,检查是否存在显式值、隐式值。
    • 如果还可以复制隐式值,那么检查此处的显式和隐式值是否匹配。如果不匹配,那么不再复制后面的隐式值。
    • 如果不能再复制隐式值,并且没有显式值,但它确实是路由需要的值,而且存在隐式值,并且和要求的值相等,那么使用它。
    • 如果在以上两种情况中匹配成功,那么将它加入 accepted 匹配成功的路由字典。
    • 如果是可选参数或者通配参数,那么将它从路由字典中忽略。
    • 如果没有匹配但是有默认值,那么使用默认值。
    • 以上过程均没有对应的话,认为此Endpoint匹配失败,不适用于此结果。
  • 对每个filter字段进行检查。
    • 如果存在显式值,检查两者是否相等,如果不相等那么就是匹配失败
    • 如果不存在显式值,那么不加入 accepted 匹配成功的路由字典。
  • 如果存在没有处理的显式值,加入 combined 合并后的路由字典。

  • 对于一些非参数的隐式值,需要加入到 combined 字典,这样可以对路由限制可见。

破坏的部分

以上对于 /[area]/{area_id}/[controller]/[action] 非常的不友好。

可以通过将该 TemplateBinder 的 _requiredValues 清空来达到这一点操作。但是这是 readonly 的 private 字段,所以使用反射解决。

using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Routing.Template;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

namespace Microsoft.AspNetCore.Routing
{
    public sealed class OrderLinkGenerator : LinkGenerator, IDisposable
    {
        private readonly LinkGenerator inner;
        private readonly TemplateBinderFactory _binderFactory;
        private readonly Func<RouteEndpoint, TemplateBinder> _createTemplateBinder;
        private readonly FieldInfo _requiredKeys;
        const string typeName = "Microsoft.AspNetCore.Routing.DefaultLinkGenerator";
        internal static Type typeInner;

        public OrderLinkGenerator(
            ParameterPolicyFactory parameterPolicyFactory,
            TemplateBinderFactory binderFactory,
            EndpointDataSource dataSource,
            IOptions<RouteOptions> routeOptions,
            IServiceProvider serviceProvider)
        {
            if (typeInner.FullName != typeName)
                throw new NotImplementedException();
            var logger = serviceProvider.GetService(typeof(ILogger<>).MakeGenericType(typeInner));
            var autoFlag = BindingFlags.NonPublic | BindingFlags.Instance;

            var args = new object[]
            {
                parameterPolicyFactory,
                binderFactory,
                dataSource,
                routeOptions,
                logger,
                serviceProvider
            };

            var ctorInfo = typeInner.GetConstructors()[0];
            var newExp = Expression.New(ctorInfo, args.Select(o => Expression.Constant(o)));
            var ctor = Expression.Lambda<Func<LinkGenerator>>(newExp).Compile();
            inner = ctor();

            _binderFactory = binderFactory;
            _createTemplateBinder = CreateTemplateBinder;
            var fieldInfo = typeInner.GetField(nameof(_createTemplateBinder), autoFlag);
            fieldInfo.SetValue(inner, _createTemplateBinder);

            _requiredKeys = typeof(TemplateBinder).GetField(nameof(_requiredKeys), autoFlag);
        }

        private TemplateBinder CreateTemplateBinder(RouteEndpoint endpoint)
        {
            /*
             * The following code section is disabled
             * for its change to RoutePattern may cause
             * errors.
             * 
             * var rawText = endpoint.RoutePattern.RawText;
             * var rv = endpoint.RoutePattern.RequiredValues as RouteValueDictionary;
             *
             * if (rawText != null)
             * {
             *     var m = Regex.Matches(rawText, "\\{(\\w+)\\}");
             *     for (int i = 0; i < m.Count; i++)
             *         rv.Add(m[i].Value.TrimStart('{').TrimEnd('}'), RoutePattern.RequiredValueAny);
             * }
             * 
             * A better solution is to disable the _requiredKeys.
             */

            var binder = _binderFactory.Create(endpoint.RoutePattern);
            _requiredKeys.SetValue(binder, Array.Empty<string>());
            return binder;
        }

        public override string GetPathByAddress<TAddress>(HttpContext httpContext, TAddress address, RouteValueDictionary values, RouteValueDictionary ambientValues = null, PathString? pathBase = null, FragmentString fragment = default, LinkOptions options = null) =>
            inner.GetPathByAddress(httpContext, address, values, ambientValues, pathBase, fragment, options);
        public override string GetPathByAddress<TAddress>(TAddress address, RouteValueDictionary values, PathString pathBase = default, FragmentString fragment = default, LinkOptions options = null) =>
            inner.GetPathByAddress(address, values, pathBase, fragment, options);
        public override string GetUriByAddress<TAddress>(HttpContext httpContext, TAddress address, RouteValueDictionary values, RouteValueDictionary ambientValues = null, string scheme = null, HostString? host = null, PathString? pathBase = null, FragmentString fragment = default, LinkOptions options = null) =>
            inner.GetUriByAddress(httpContext, address, values, ambientValues, scheme, host, pathBase, fragment, options);
        public override string GetUriByAddress<TAddress>(TAddress address, RouteValueDictionary values, string scheme, HostString host, PathString pathBase = default, FragmentString fragment = default, LinkOptions options = null) =>
            inner.GetUriByAddress(address, values, scheme, host, pathBase, fragment, options);
        public void Dispose() => ((IDisposable)inner).Dispose();
    }
}

最后在依赖注入容器中替换即可。

public static IMvcBuilder ReplaceLinkGenerator(this IMvcBuilder mvc)
{
    var old = mvc.Services.FirstOrDefault(s => s.ServiceType == typeof(LinkGenerator));
    OrderLinkGenerator.typeInner = old.ImplementationType;
    mvc.Services.Replace(ServiceDescriptor.Singleton<LinkGenerator, OrderLinkGenerator>());
    return mvc;
}

ASP.NET Core中使用Basic Authentication

前段时间需要给某个项目接入Web API,并且使用的是基于HTTP Header的Basic Authorzation。

最开始的写法是在appsettings.json里开辟字段存储用户名和密码,然后给对应的Controller加IActionFilter,后来随着项目的升级,想着接入ASP.NET Core自带的用户系统。

翻了翻Microsoft Docs,并没有找到一些科学的内容。后来在NuGet上找到了idunno.Authentication.Basic这样的一个NuGet包。

这是其GitHub Repo上提供的代码示例。

public void ConfigureServices(IServiceCollection services)
{
    services.AddAuthentication(BasicAuthenticationDefaults.AuthenticationScheme)
            .AddBasic(options =>
            {
                options.Realm = "idunno";
                options.Events = new BasicAuthenticationEvents
                {
                    OnValidateCredentials = context =>
                    {
                        if (context.Username == context.Password)
                        {
                            var claims = new[]
                            {
                                new Claim(
                                    ClaimTypes.NameIdentifier, 
                                    context.Username, 
                                    ClaimValueTypes.String, 
                                    context.Options.ClaimsIssuer),
                                new Claim(
                                    ClaimTypes.Name, 
                                    context.Username, 
                                    ClaimValueTypes.String, 
                                    context.Options.ClaimsIssuer)
                            };

                            context.Principal = new ClaimsPrincipal(
                                new ClaimsIdentity(claims, context.Scheme.Name));
                            context.Success();
                        }

                        return Task.CompletedTask;
                    }
                };
            });

    // All the other service configuration.
}

在此之后,给需要使用Basic的Controller添加

[Authorize(AuthenticationSchemes = "Basic")]

即可。


众所周知,ASP.NET Core使用的是基于Claim的认证与授权。

如果想使用自带的用户系统(services.AddIdentity()那套),则直接使用SignInManager就好啦。

那么大致的验证代码(OnValidateCredentials)就是

private static async Task ValidateAsyncOld(ValidateCredentialsContext context)
{
    var signInManager = context.HttpContext.RequestServices
        .GetRequiredService<SignInManager<TUser>>();
    var userManager = signInManager.UserManager;
    var user = await userManager.FindByNameAsync(context.Username);

    if (user == null)
    {
        context.Fail("User not found.");
        return;
    }

    var checkPassword = await signInManager.CheckPasswordSignInAsync(user, context.Password, false);
    if (!checkPassword.Succeeded)
    {
        context.Fail("Login failed, password not match.");
        return;
    }

    context.Principal = await signInManager.CreateUserPrincipalAsync(user);
    context.Success();
}

此后发现验证的性能问题。

显然每次访问API接口都会进行四次数据库查询!(Users一次,UserClaims一次,Roles一次,RoleClaims一次)

在轮询的接口中出现了非常严重的性能问题,毕竟主要程序代码只有一个查询,而用户信息就要查询四次。


不过不要紧,我们还有MemoryCache。

将用户信息缓存在内存中,并且定期清除,让几百次登录同一账户查询才生成一次ClaimsIdentity就会将查询用户信息的时间均摊的更低了。

同时由于我的这个项目不使用UserClaims和RoleClaims,也可以在ClaimsIdentity生成过程中忽略这两个步骤。在用户信息的查询中也可以省略掉一堆一堆一堆一堆不需要的字段。

根据SignInManager、UserManager等的源码,最后挖掘出来了四个接口需要使用

  • IPasswordHasher<TUser>
  • IOptions<IdentityOptions>
  • IdentityDbContext<...>
  • IUserClaimsPrincipalFactory<TUser> (已经展开到最后代码中)

这样就不需要使用UserManager、SignInManager而底层的直接生成ClaimsIdentity和ClaimsPrincipal。

private readonly static IMemoryCache _cache =
    new MemoryCache(new MemoryCacheOptions()
    {
        Clock = new Microsoft.Extensions.Internal.SystemClock()
    });

private static async Task ValidateAsync(ValidateCredentialsContext context)
{
    var dbContext = context.HttpContext.RequestServices
        .GetRequiredService<TContext>();
    var normusername = context.Username.ToUpper();

    var user = await _cache.GetOrCreateAsync("`" + normusername.ToLower(), async entry =>
    {
        var value = await dbContext.Users
            .Where(u => u.NormalizedUserName == normusername)
            .Select(u => new { u.Id, u.UserName, u.PasswordHash, u.SecurityStamp })
            .FirstOrDefaultAsync();
        entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(5);
        return value;
    });

    if (user == null)
    {
        context.Fail("User not found.");
        return;
    }

    var passwordHasher = context.HttpContext.RequestServices
        .GetRequiredService<IPasswordHasher<TUser>>();

    var attempt = passwordHasher.VerifyHashedPassword(
        user: default, // assert that hasher don't need TUser
        hashedPassword: user.PasswordHash,
        providedPassword: context.Password);

    if (attempt == PasswordVerificationResult.Failed)
    {
        context.Fail("Login failed, password not match.");
        return;
    }

    var principal = await _cache.GetOrCreateAsync(normusername, async entry =>
    {
        var uid = user.Id;
        var ur = await dbContext.UserRoles
            .Where(r => r.UserId.Equals(uid))
            .Join(dbContext.Roles, r => r.RoleId, r => r.Id, (_, r) => r.Name)
            .ToListAsync();

        var options = context.HttpContext.RequestServices
            .GetRequiredService<IOptions<IdentityOptions>>().Value;

        // REVIEW: Used to match Application scheme
        var id = new ClaimsIdentity("Identity.Application",
            options.ClaimsIdentity.UserNameClaimType,
            options.ClaimsIdentity.RoleClaimType);
        id.AddClaim(new Claim(options.ClaimsIdentity.UserIdClaimType, $"{user.Id}"));
        id.AddClaim(new Claim(options.ClaimsIdentity.UserNameClaimType, user.UserName));
        id.AddClaim(new Claim(options.ClaimsIdentity.SecurityStampClaimType, user.SecurityStamp));
        foreach (var roleName in ur)
            id.AddClaim(new Claim(options.ClaimsIdentity.RoleClaimType, roleName));
        var value = new ClaimsPrincipal(id);

        entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(20);
        return value;
    });

    context.Principal = principal;
    context.Success();
}

缓存的5分钟/20分钟可以根据需求来修改。

这样就有一个似乎“完美”一点的Basic Authentication啦!

2019ICPC西安 C. Dirichlet k-th root

我不会狄利克雷卷积。输得彻底。

给出一个很直观的解法代码。

前提是理解好 f(n) = xf^k(n) = kx+b_{k,n}

那么

f^{2k}(n) = (f^k * f^k)(n) = 2f^k(n) + \sum_{d|n, d\not=1,d\not=n} f^k(d) f^k(n/d)

f^{k+1}(n) = (f^k * f)(n) = f^k(n) + f(n) + \sum_{d|n, d\not=1,d\not=n} f(d) f^k(n/d)

首先显然 f^k(1) = 1,假设已经求出 f(1) \cdots f(n-1) 里的所有值,以及可能取到的上标 k,那么就可以得到 g(n) = f^K(n) = Kx+b_{K,n}。我们可以从小往大递推 b_{i,n},然后求得 b_{K,n} 后得到 f(n),并解出 f^k(n) 以方便后续计算。

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;
const int mod = 998244353;
int g[maxn], f[maxn][64], b[maxn][64];
vector<int> d[maxn];
int n, k, kk[64], invk, tok;

int fpow(long long a, int k) {
    long long b = 1;
    for (; k; k >>= 1, a = a * a % mod)
        if (k & 1) b = b * a % mod;
    return b;
}

int main()
{
    scanf("%d %d", &n, &k);
    for (int i = 1; i <= n; i++)
        scanf("%d", &g[i]);
    for (int i = 2; i <= n; i++)
        for (int j = i+i; j <= n; j += i)
            d[j].push_back(i);
    for (int tmp = k; tmp; ) {
        kk[++tok] = tmp;
        if (tmp % 2 == 1) tmp--; else tmp /= 2;
    } // the used ks
    invk = fpow(k, mod-2);

    // find for each f(n)
    for (int i = 1; i <= tok; i++) f[1][i] = 1;
    for (int i = 2; i <= n; i++)
    {
        // if we know about f[1..n-1], let f[n] = x
        // so it can be proved that fk[n] = kx+b
        for (int j = tok-1; j; j--)
        {
            long long tmp = b[i][j+1];
            if (!(kk[j] & 1)) tmp <<= 1;
            int sec = (kk[j] & 1) ? tok : j+1;
            for (auto dd : d[i])
                tmp += 1ll * f[dd][j+1] * f[i/dd][sec] % mod;
            b[i][j] = tmp % mod;
        }

        // now we have g[n] = Kx+b, just...
        f[i][tok] = 1ll * (g[i] - b[i][1] + mod) * invk % mod;
        for (int j = tok-1; j; j--)
            f[i][j] = (1ll * kk[j] * f[i][tok] + b[i][j]) % mod;
        assert(f[i][1] == g[i]);
    }

    for (int i = 1; i <= n; i++)
        printf("%d ", f[i][tok]);
    return 0;
}

复杂度为 O(n\log n \log k)

另外给出一个道听途说的做法,当时完全没有想到的。

如果你手算能力够强,那么可以发现 b 是个类似于求和的东西,结果对 k 敏感,如果令 k = 998244353,那么求和结果为 0,对 g\operatorname{inv}k 次也可以得到 f

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;
const int mod = 998244353;
int g[maxn], f[maxn], n;

int fpow(long long a, int k) {
    long long b = 1;
    for (; k; k >>= 1, a = a * a % mod)
        if (k & 1) b = b * a % mod;
    return b;
}

void conv(int a[maxn], const int b[maxn]) {
    static int c[maxn];
    fill(c, c+n+1, 0);
    for (int i = 1; i <= n; i++)
        for (int j = 1; j * i <= n; j++)
            c[j*i] = (c[j*i] + 1ll * a[i] * b[j]) % mod;
    copy(c, c+n+1, a);
}

int main()
{
    int k; scanf("%d %d", &n, &k);
    for (int i = 1; i <= n; i++)
        scanf("%d", &g[i]);
    k = fpow(k, mod-2); f[1] = 1;
    for (; k; k >>= 1, conv(g, g))
        if (k & 1) conv(f, g);
    for (int i = 1; i <= n; i++)
        printf("%d ", f[i]);
    return 0;
}

不太会证明,能AC,复杂度为 O(n\log n \log \operatorname{inv} k)

2018ICPC北京 C. Pythagorean triple

数出所有满足 a^2+b^2=c^2c \le n 的正整数三元组。

首先,如果 (x,y,z) 是一个勾股数三元组,那么 (ax,ay,az) 也是个勾股数三元组。那么,不妨考虑 f(n) 为斜边长为 n 的勾股三元组数量,g(n) 为斜边长为 n 的本原勾股数,那么有

F(n) = \sum_{i=1}^n f(i) = \sum_{i=1}^n \sum_{d|i} g(d) = \sum_{d=1}^n \sum_{i=1}^{\lfloor n/d \rfloor} g(i) = \sum_{d=1}^n G(\left\lfloor\frac{n}{d}\right\rfloor)

考虑到所有本原勾股三元组可以表示为 (i^2-j^2, 2ij, i^2+j^2),其中 \gcd(i,j)=1,并且不能同时为奇数,那么有

\begin{aligned} G(n) & = \sum_{i 为奇} \sum_{j 为偶} [i^2+j^2 \le n] [\gcd(i,j)=1] \\ & = \sum_{d=1}^{\sqrt{n/2}} \mu(d) \sum_{i 为奇} \sum_{j 为偶} [i^2+j^2 \le n] [d|i] [d|j] \\ & = \frac{1}{2} \sum_{d=1}^{\sqrt{n/2}} \mu(d) \left( \sum_{i} \sum_{j} – \sum_{i 为奇} \sum_{j 为奇}\right ) [i^2+j^2 \le n] [d|i] [d|j] \\ & = \frac{1}{2} \sum_{d=1}^{\sqrt{n/2}} \mu(d) \left( \sum_{i} \left\lfloor\frac{\sqrt{n-i^2d^2}}{d}\right\rfloor – [d为奇] \sum_{i 为奇} \left\lfloor\frac{\frac{\sqrt{n-i^2d^2}}{d}+1}{2}\right\rfloor \right ) \end{aligned}

感觉写到这里一脸复杂度会超掉的样子,那么继续化简试试。

\begin{aligned} F(n) & = \sum_{c=1}^n G(\left\lfloor\frac{n}{c}\right\rfloor) \\ & = \frac{1}{2} \sum_{c=1}^n \sum_{d=1}^{\sqrt{n/2c}} \mu(d) \left( \sum_{i} \left\lfloor\frac{\sqrt{n/c-i^2d^2}}{d}\right\rfloor – [d为奇] \sum_{i 为奇} \left\lfloor\frac{\frac{\sqrt{n/c-i^2d^2}}{d}+1}{2}\right\rfloor \right ) \\ & = \frac{1}{2} \sum_{d=1}^{\sqrt{n/2}} \mu(d) \sum_{c=1}^n \left( \sum_{i} \left\lfloor\sqrt{\frac{n}{cd^2}-i^2}\right\rfloor – [d为奇] \sum_{i 为奇} \left\lfloor\frac{1}{2}\left(\sqrt{\frac{n}{cd^2}-i^2}+1\right)\right\rfloor \right ) \end{aligned}

那么可以发现最后那个对 c 求和的东西对外只和 n/d^2 有关系,对内可以数论分块,一脸时间复杂度 O(n^{3/4}) 的模样,复杂度够了,加个记忆化。冲冲冲~

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;

inline int sqrtd(int x)
{
    int tot = sqrtl(x);
    while (tot * tot < x) tot++;
    while (tot * tot > x) tot--;
    return tot;
}

inline pair<lld,lld> c(int n)
{
    static unordered_map<int,pair<lld,lld>> _C;
    if (_C.count(n)) return _C[n];
    lld p1 = 0, p2 = 0; int m = sqrtd(n);

    for (int i = 1, j = m; i <= m; i++)
    {
        while (i * i + j * j > n) j--;
        p1 += j;
        if (i & 1) p2 += (j+1)/2;
    }

    return _C[n] = make_pair(p1, p2);
}

inline lld h(int n, bool odd)
{
    lld ans = 0;

    for (int L = 1, R; L <= n; L = R+1)
    {
        R = n / ( n / L );
        auto cur = c(n/L);
        ans += (R-L+1) * cur.first;
        if (odd) ans -= (R-L+1) * cur.second;
    }

    return ans;
}

int main()
{
    const int MAXN = 1e6+5;
    static int pri[MAXN], pcn;
    static char isnp[MAXN], miu[MAXN];
    miu[1] = 1;

    for (int i = 2; i < MAXN; i++)
    {
        if (!isnp[i]) pri[pcn++] = i, miu[i] = -1;
        for (int j = 0; j < pcn && i * pri[j] < MAXN; j++)
        {
            isnp[i * pri[j]] = 1;
            if (i % pri[j] == 0) break;
            miu[i * pri[j]] = -miu[i];
        }
    }

    int T, n;
    scanf("%d", &T);

    while (T--)
    {
        scanf("%d", &n);
        lld ans = 0;
        for (int d = 1; d * d <= n; d++)
            if (miu[d]) ans += miu[d] * h(n/d/d, d&1);
        printf("%lld\n", ans/2);
    }

    return 0;
}

2019CCPC秦皇岛 H. Houraisan Kaguya

题目链接

首先,如果你的群论足够好,你应该能理解到,

f(a,b) = \frac{\operatorname{ord}a}{\gcd(\operatorname{ord}a,\operatorname{ord}b)}

所以题目实际上是求

\sum_{i=1}^n \sum_{j=1}^n \frac{\operatorname{ord}a_i \times \operatorname{ord}a_j}{\gcd(\operatorname{ord}a_i,\operatorname{ord}a_j)^2}

然后现场赛到这里我就不会了(雾)

实际上,我们可以按莫比乌斯反演的套路想到,枚举 gcd。我们考虑到 \operatorname{ord}x 也一定是 p-1 的倍数,所以考虑这样的一个卷积

\begin{aligned} c_k &= \sum_{\gcd(i, j) = k} ijf_if_j \\ &=\sum_{i,j} if_i ~ jf_j [\gcd(i,j)=k] \end{aligned}

其中 f_i 表示 \operatorname{ord} a_x = i 的个数,那么考虑枚举 \gcd 的倍数,则有

\begin{aligned} {c’}_k &= \sum_{x}c_x[k|x] \\ &= \sum_{i,j} if_i ~ jf_j [k|\gcd(i,j)] \\ &= \sum_{i,j} if_i ~ jf_j [k|i][k|j] \\ &= \sum_{i} if_i [k|i] \sum_{j} jf_j [k|j] \end{aligned}

也就是需要作一个类似于

{c’}_k = \sum_{x} c_x[k|x]

的正变换和逆变换。考虑到 FWT 中那个类似于按位与卷积的东西,设计一个类似于高位前缀和的东西,就可以做了。

另外求 \operatorname{ord}x 也有一点注意。朴素的想法是先将 x 拉入模 p 非零元素乘法群的一个 {p_i}^{e_i} 阶子群中,也就是先设计 x_i = x^{\frac{p-1}{{p_i}^{e_i}}},然后求 x_i 在该子群中的周期。所以可以设计一个时间复杂度为 O(c(p-1) \log(p-1)) 的算法。在此过程中,计算的瓶颈在于 x_i 的计算,耗费了绝大部分的时间复杂度,那么我们可以设计一个分治的算法,将因子集合划分为两部分,前半部分将后半部分素因子子群的周期升阶到 1,后半部分同样操作前半部分,然后继续分治。类似于多项式多点求值的一个想法,时间复杂度降为 O(\log c(p-1) \log(p-1))

目前这份代码是CF上跑的最快的,756ms,用普通求阶时间为 1575ms。

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
typedef long double flt;
const int MAXN = 2e5+5, S = 10;
int n, w, fc[20], dcn;
lld mod, fp[20], zs[20];

inline lld mul(lld x, lld y, lld z) { lld t = flt(x) * y / z; lld ans = (x * y - t * z) % z; if (ans < 0) ans += z; return ans; }
inline lld mul(lld x, lld y) { lld t = flt(x) * y / mod; lld ans = (x * y - t * mod) % mod; if (ans < 0) ans += mod; return ans; }
inline void addeq(lld &x, lld y) { x = x+y - (x+y>=mod?mod:0); }
inline void muleq(lld &x, lld y) { lld t = flt(x) * y / mod; lld ans = (x * y - t * mod) % mod; if (ans < 0) ans += mod; x = ans; }
inline lld fpow(lld x, lld n, lld mod) { x %= mod; if (n == 1) return x; lld ret = 1; for (; n; n>>=1, x=mul(x,x,mod)) if (n&1) ret=mul(ret,x,mod); return ret; }
inline lld fpow(lld x, lld n) { x %= mod; if (n == 1) return x; lld ret = 1; for (; n; n>>=1, x=mul(x,x,mod)) if (n&1) ret=mul(ret,x,mod); return ret; }

namespace Decompose {
    int tol; lld factor[1000];

    bool millerRabin(lld n, lld base) {
        lld n2 = n-1, s = __builtin_ctzll(n2); n2 >>= s;
        lld t = fpow(base, n2, n);
        if (t == 1 || t == n-1) return true;
        for (s--; s >= 0; s--)
            if ((t=mul(t,t,n))==n-1)
                return true;
        return false;
    }

    bool isPrime(lld n) {
        static lld bases[12] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37 };
        if (n <= 2) return false;
        for (int i = 0; i < 12 && bases[i] < n; i++)
            if (!millerRabin(n, bases[i])) return false;
        return true;
    }

    inline lld f(lld x, lld mod) {
        lld ans = 1; addeq(ans, mul(x, x, mod)); return ans;
    }

    void PollardRho(lld n) {
        if (n == 1) {
            return;
        } else if (isPrime(n)) {
            factor[tol++] = n;
        } else if (n & 1) {
            for (int i = 1; ; i++) {
                lld x = i, y = f(x, n), q = __gcd(y-x, n);
                while (q == 1) {
                    x = f(x, n), y = f(f(y, n), n);
                    q = __gcd((y-x+n)%n, n) % n;
                }
                if (q != 0 && q != n) {
                    PollardRho(q), PollardRho(n/q);
                    return;
                }
            }
        } else {
            while (!(n & 1)) factor[tol++] = 2, n >>= 1;
            PollardRho(n);
        }
    }

    inline void findfac(lld n) {
        tol = w = 0; PollardRho(n); map<lld,int> tj;
        for (int i = 0; i < tol; i++) tj[factor[i]]++;
        for (auto ss : tj) {
            fc[w] = ss.second, fp[w] = ss.first;
            zs[w] = 1; for (int i = 0; i < fc[w]; i++) zs[w] *= fp[w];
            w++;
        }
    }
}

using Decompose::findfac;

lld getOrd(lld a) {
    lld res = mod-1;
    for (int i = 0; i < w; i++) {
        lld qwq = fpow(a, res/=zs[i]);
        while (qwq != 1) qwq = fpow(qwq, fp[i]), res *= fp[i];
    }
    return res;
} // Time complexity: O(c(n)log(n))

struct Node { int l, r; lld val; };

lld getOrdFast(lld a) {
    static Node Q[100000]; int front = 0, rear = 0;
    lld res = 1; Q[rear++] = Node { 0, w-1, a };
    while (front < rear) {
        Node f = Q[front++];
        if (f.l == f.r) {
            for (lld bs = f.val; bs != 1; bs = fpow(bs, fp[f.l]))
                res *= fp[f.l];
        } else if (f.l < f.r) {
            int mid = (f.l+f.r)>>1;
            lld prod1 = 1, prod2 = 1;
            for (int i = f.l; i <= mid; i++) prod2 *= zs[i];
            for (int i = mid+1; i <= f.r; i++) prod1 *= zs[i];
            Q[rear++] = Node { f.l, mid, fpow(f.val, prod1) };
            Q[rear++] = Node { mid+1, f.r, fpow(f.val, prod2) };
        }
    }
    return res;
} // Time complexity: O(log(c(n))log(n))

lld di[MAXN], invdi2[MAXN]; int qp[20];
lld a[MAXN], g[MAXN];

void genFac(int wi, lld fac, int ids) {
    if (wi == w) {
        di[ids] = fac;
        invdi2[ids] = fac==1?1:mod-mod/fac;
        muleq(invdi2[ids], invdi2[ids]);
    } else {
        lld pr = 1;
        for (int i = 0; i <= fc[wi]; i++, pr *= fp[wi])
            genFac(wi+1, fac*pr, ids+i*qp[wi]);
    }
}

void addResult(lld ord)
{
    int id = 0; lld n = ord;

    for (int i = 0; i < w; i++)
    {
        int r = 0;
        while (n % fp[i] == 0) n /= fp[i], r++;
        id += r * qp[i];
    }

    addeq(g[id], ord);
}

void solve()
{
    findfac(mod - 1);

    qp[0] = 1;
    for (int i = 1; i < w; i++)
        qp[i] = qp[i-1] * (fc[i-1] + 1);
    dcn = qp[w-1] * (fc[w-1] + 1);
    genFac(0, 1, 0);

    for (int i = 1; i <= n; i++)
        addResult(getOrdFast(a[i]));
    for (int i = 0; i < w; i++)
        for (int j = dcn-1; j >= 0; j--)
            if (j / qp[i] % (fc[i]+1) != fc[i])
                addeq(g[j], g[j+qp[i]]);
    for (int i = 0; i < dcn; i++)
        muleq(g[i], g[i]);
    for (int i = 0; i < w; i++)
        for (int j = 0; j < dcn; j++)
            if (j / qp[i] % (fc[i]+1) != fc[i])
                addeq(g[j], mod-g[j+qp[i]]);
    lld ans = 0;
    for (int i = 0; i < dcn; i++)
        addeq(ans, mul(g[i], invdi2[i]));
    printf("%lld\n", ans);
}

int main()
{
    srand(time(NULL));
    scanf("%d %lld", &n, &mod);
    for (int i = 1; i <= n; i++)
        scanf("%lld", &a[i]);
    solve();
    return 0;
}

2019 ICPC上海网络赛部分题解

C. Triple

考虑分情况不同做法。当 n \le 1000 时,枚举 A_i, B_j,则根据不等式解出满足条件的 C_k \in [|A_i-B_j|,A_i+B_j]。当 n > 1000,考虑枚举不合法的方案数,即形如 A_i + B_j < C_k 的方案数,不等式左边可以通过FFT取得,然后前缀和再枚举 C_k 容斥掉即可。

大力施加常数优化,AC仅需1256ms啦啦啦~

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
typedef complex<double> cplx;
int revs[1<<19], A[300010]; cplx wmks[1<<19];
const double PI = acos(-1.0);

void dft(cplx a[], int DFT, int N) {
    int *rev = revs + N;
    for (int i = 0; i < N; i++)
        if (i < revs[N+i]) swap(a[i], a[rev[i]]);
    for (int m = 2, m2 = 1; m <= N; m <<= 1, m2 <<= 1) {
        cplx *wmk = wmks + m + (DFT==-1 ? m2 : 0), u, t;
        for (int k = 0; k < N; k += m)
            for (int j = 0; j < m2; j++)
                t = wmk[j] * a[k+j+m2], u = a[k+j],
                a[k+j] = u + t, a[k+j+m2] = u - t;
    }
    if (DFT == -1) for (int i = 0; i < N; i++) a[i] /= N;
}

int a[100010], b[100010], c[100010];

lld solveWithFFT(int n)
{
    static cplx f[262144], g[262144], h[262144];
    static lld bc[262144], ac[262144], ab[262144];
    int ma = 0, mb = 0, mc = 0, md, len = 1; cplx tmp, tmp2;
    for (int i = 0; i < n; i++) ma = max(ma, a[i]), mb = max(mb, b[i]), mc = max(mc, c[i]);
    md = max(ma+mb, max(mb+mc, mc+ma)) + 2; while (len < md) len <<= 1;
    for (int i = 0; i < len; i++) f[i] = g[i] = h[i] = 0;
    for (int i = 0; i < n; i++) f[a[i]] += 1, g[b[i]] += 1, h[c[i]] += 1;
    dft(f, 1, len), dft(g, 1, len), dft(h, 1, len);
    for (int i = 0; i < len; i++) tmp = f[i] * g[i], tmp2 = g[i] * h[i], g[i] = f[i] * h[i], f[i] = tmp2, h[i] = tmp;
    dft(f, -1, len), dft(g, -1, len), dft(h, -1, len);
    for (int i = 0; i < len; i++) bc[i] = f[i].real()+0.2, ac[i] = g[i].real()+0.2, ab[i] = h[i].real()+0.2;
    for (int i = 1; i < len; i++) bc[i] += bc[i-1], ac[i] += ac[i-1], ab[i] += ab[i-1];
    lld ans = 1LL * n * n * n;

    for (int i = 0; i < n; i++)
    {
        if (c[i] <= len) ans -= ab[c[i]-1];
        if (b[i] <= len) ans -= ac[b[i]-1];
        if (a[i] <= len) ans -= bc[a[i]-1];
    }

    return ans;
}

inline int abs(int x) { return x < 0 ? -x : x; }

lld solveWithBruteForce(int n)
{
    for (int i = 0; i <= 200001; i++) A[i] = 0;
    for (int i = 0; i < n; i++) A[c[i]]++;
    for (int i = 1; i <= 200001; i++) A[i] += A[i-1];
    lld ans = 0;
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            ans += A[a[i]+b[j]] - A[max(abs(a[i]-b[j])-1,0)];
    return ans;
}

void solve(int cas)
{
    int n; scanf("%d", &n);
    for (int i = 0; i < n; i++) scanf("%d", &a[i]);
    for (int i = 0; i < n; i++) scanf("%d", &b[i]);
    for (int i = 0; i < n; i++) scanf("%d", &c[i]);
    printf("Case #%d: %lld\n", cas, n <= 1000 ? solveWithBruteForce(n) : solveWithFFT(n));
}

int main()
{
    for (int N = 2, N2 = 1; N <= (1<<18); N <<= 1, N2 <<= 1)
    {
        int *rev = revs + N; cplx *wmk = wmks + N;
        for (int i = 0; i < N; i++)
            rev[i] = (rev[i>>1]>>1)|((i&1)?N2:0);
        for (int i = 0; i < N2; i++)
            wmk[i] = cplx(cos(PI*i/N2), sin(PI*i/N2));
        for (int i = 0; i < N2; i++)
            wmk[i+N2] = cplx(wmk[i].real(), -wmk[i].imag());
    }

    int T; scanf("%d", &T);
    for (int i = 1; i <= T; i++) solve(i);
    return 0;
}

D. Counting Sequences I

首先发现是几个本质不同的方案的排列,于是想到打表。需要一个快速出结果的剪枝。

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
const int MOD = 1e9+7;
int n, jspx[3010], val[3010];
lld fac[3010], inv[3010], invs[3010], ans;

void dfs(int cur, int last, lld mul, int sum)
{
    if (cur == n)
    {
        if (mul == sum)
        {
            for (int i = 1; i <= n; i++)
                if (jspx[i])
                    printf("%dx%d ", jspx[i], i);
            printf("\n");
            lld tot = fac[n];
            for (int i = 1; i <= n; i++)
                tot = tot * invs[jspx[i]] % MOD;
            ans = (ans + tot) % MOD;
        }
    }
    else if (mul * pow(last, n-cur) <= sum + last * (n - cur))
    {
        for (int i = last; i <= n; i++)
        {
            val[cur] = i;
            jspx[i]++;
            dfs(cur + 1, i, mul * i, sum + i);
            jspx[i]--;
        }
    }
}

int main()
{
    fac[0] = fac[1] = invs[0] = invs[1] = inv[1] = 1;
    for (int i = 2; i < 3010; i++)
        fac[i] = i * fac[i-1] % MOD,
        inv[i] = (MOD-MOD/i) * inv[MOD%i] % MOD,
        invs[i] = invs[i-1] * inv[i] % MOD;
    scanf("%d", &n);
    dfs(0, 1, 1, 0);
    printf("ANS = %lld\n", ans);
    return 0;
}

E. Counting Sequences II

不知道题解在写什么鬼东西……

考虑 [1,m] 中的每个数字 t 的排列型生成函数。如果 t 为奇数,则其生成函数为 f(x) = 1 + \frac{x}{1!} + \frac{x^2}{2!} + \cdots = e^x;否则只能出现偶数次,则为 g(x) = 1 + \frac{x^2}{2!} + \frac{x^4}{4!} + \cdots = \frac{e^x + e^{-x}}{2} = \ch(x)

那么答案相当于求

\left(\frac{e^x+e^{-x}}{2}\right)^{\left\lfloor \frac{m}{2} \right\rfloor} \left(e^x\right)^{\left\lceil \frac{m}{2} \right\rceil}

\frac{x^n}{n!} 项系数。注意到 \left\lceil \frac{m}{2} \right\rceil = \left\lfloor \frac{m}{2} \right\rfloor + (m \mod 2),令 k = \left\lfloor \frac{m}{2} \right\rfloor,则可以化简为

\frac{1}{2^k} \sum_{s=0}^k C_{k}^{s} e^{(2s+(m\mod 2))x}

即最后答案为

\frac{1}{2^k} \sum_{s=0}^k \frac{k! (2s+(m\mod 2))^n}{s! (k-s)!}

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
const int MOD = 1e9+7, MAXN = 2e5+5;
int fac[MAXN], inv[MAXN], invs[MAXN];

int fpow(lld a, lld k)
{
    lld b = 1; k %= MOD-1;
    for (; k; k >>= 1, a = a * a % MOD)
        if (k & 1) b = b * a % MOD;
    return b;
}

void solve()
{
    lld ans = 0, n; int m, k;
    scanf("%lld %d", &n, &m); k = m / 2;
    for (int s = 0; s <= k; s++)
        ans += 1ll * fac[k] * invs[s] % MOD * invs[k-s] % MOD * fpow(2*s+(m&1), n) % MOD;
    ans = ans % MOD * fpow((MOD+1)/2, k) % MOD;
    printf("%lld\n", ans);
}

int main()
{
    fac[0] = fac[1] = inv[1] = 1;
    invs[0] = invs[1] = 1;
    for (int i = 2; i < MAXN; i++)
        fac[i] = 1ll * i * fac[i-1] % MOD,
        inv[i] = 1ll * (MOD-MOD/i) * inv[MOD%i] % MOD,
        invs[i] = 1ll * invs[i-1] * inv[i] % MOD;
    int T; scanf("%d", &T);
    while (T--) solve();
    return 0;
}

H. Luhhy’s Matrix

考虑让每个矩阵只被运算一次,那么可以使用类似于分块的想法设立分界线。

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
const int MAXN = 2e5+5;
unsigned lastans, seed, pow17[16], pow19[16];
char cost[MAXN];

struct matrix
{
    char item[16][16];
    int col[16], row[16];
    // let row[k] = item[i][k], col[k] = item[k][i]

    void norm()
    {
        for (int k = 0; k < 16; k++)
        {
            row[k] = col[k] = 0;
            for (int i = 0; i < 16; i++)
                row[k] = (row[k] << 1) | (item[i][k] & 1),
                col[k] = (col[k] << 1) | (item[k][i] & 1);
        }
    }

    matrix operator*(const matrix &b) const
    {
        matrix ans;
        for (int i = 0; i < 16; i++)
            for (int j = 0; j < 16; j++)
                ans.item[i][j] = cost[col[i]&b.row[j]];
        ans.norm();
        return ans;
    }

    static matrix getOne()
    {
        matrix ans; memset(ans.item, 0, sizeof(ans.item));
        for (int i = 0; i < 16; i++) ans.item[i][i] = 1;
        ans.norm(); return ans;
    }

    unsigned getAns() const
    {
        unsigned ans = 0;
        for (int i = 0; i < 16; i++)
            for (int j = 0; j < 16; j++)
                ans += item[i][j] * pow17[i] * pow19[j];
        return ans;
    }

    void output() const
    {
        for (int i = 0; i < 16; i++)
            for (int j = 0; j < 16; j++)
                printf("%d%c", int(item[i][j]), " \n"[j==15]);
    }

    static matrix generate()
    {
        matrix ans;
        seed ^= lastans;
        for (int i = 0; i < 16; i++)
        {
            seed ^= seed * seed + 15;
            for (int j = 0; j < 16; j++)
                ans.item[i][j] = (seed >> j) & 1;
        }
        ans.norm();
        return ans;
    }
} Q[MAXN], tail;

void solve()
{
    int front = 0, rear = 0, cut = 0, n; lastans = 0;
    scanf("%d", &n); tail = matrix::getOne();

    while (n--)
    {
        int t; scanf("%d %u", &t, &seed);

        if (t == 1)
        {
            Q[rear++] = matrix::generate();
            if (rear == front+1) tail = matrix::getOne(), cut = rear;
            else tail = Q[rear-1] * tail;
        }
        else
        {
            front++;

            if (front == cut)
            {
                cut = rear;
                for (int i = rear-2; i >= front; i--) Q[i] = Q[i+1] * Q[i];
                tail = matrix::getOne();
            }
        }

        lastans = (front == rear) ? 0 : (tail * Q[front]).getAns();
        printf("%u\n", lastans);
    }
}

int main()
{
    for (int sz2 = 1; sz2 < 65536; sz2 <<= 1)
        for (int i = 0; i < sz2; i++)
            cost[sz2+i] = cost[i] ^ 1;
    pow17[0] = pow19[0] = 1;
    for (int i = 1; i < 16; i++)
        pow17[i] = pow17[i-1] * 17,
        pow19[i] = pow19[i-1] * 19;
    int T; scanf("%d", &T);
    while (T--) solve();
    return 0;
}

K. Peekaboo

圆上整点。直接分解 a,b 然后暴力枚举点对。看题解以为卡了64倍常数,结果交了就过了(大雾)

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
typedef pair<int,int> pear;
#define F first
#define S second
inline lld sqr(int a) { return 1ll * a * a; }
lld gcd(lld a, lld b) { return b ? gcd(b, a%b) : a; }

void gaussInteger(lld R, vector<pear> &tot)
{
    auto check = [&tot] (lld r, lld rr) -> void
    {
        if (r == 1 || r % 4 != 1) return;
        for (int n = 1; n*n*2 < r; n++)
        {
            int m = int(sqrt(r-n*n)+1e-5);
            if (m * m + n * n != r) continue;
            if (gcd(m,n) != 1 || m <= n) continue;
            tot.emplace_back(rr*(m*m-n*n), 2*n*m*rr);
        }
    };

    for (int i = 1; i * i <= R; i++)
    {
        if (R % i) continue;
        check(R/i, i);
        if (i * i != R) check(i, R/i);
    }
}

void fillOver(vector<pear> &proc, int r)
{
    int cnt = proc.size();
    proc.resize(cnt*8+4);
    for (int i = 0; i < cnt; i++)
        proc[i+cnt] = pear(proc[i].S, proc[i].F);
    for (int i = 0; i < 2*cnt; i++)
        proc[i+cnt*2] = pear(proc[i].F, -proc[i].S),
        proc[i+cnt*4] = pear(-proc[i].F, proc[i].S),
        proc[i+cnt*6] = pear(-proc[i].F, -proc[i].S);
    proc[cnt*8] = pear(0, r);
    proc[cnt*8+1] = pear(r, 0);
    proc[cnt*8+2] = pear(0, -r);
    proc[cnt*8+3] = pear(-r, 0);
}

pair<pear,pear> ans[100050];

void solve()
{
    vector<pear> possA, possB;
    int a, b, c, tot = 0;
    scanf("%d %d %d", &a, &b, &c);
    gaussInteger(a, possA), gaussInteger(b, possB);
    fillOver(possA, a), fillOver(possB, b);

    for (auto A : possA) for (auto B : possB)
    {
        lld r2 = sqr(c), r3 = sqr(A.F-B.F)+sqr(A.S-B.S);
        if (r2 == r3) ans[tot++] = make_pair(A, B);
    }

    sort(ans, ans+tot);
    printf("%d\n", tot);
    for (int i = 0; i < tot; i++) printf("%d %d %d %d\n", ans[i].F.F, ans[i].F.S, ans[i].S.F, ans[i].S.S);
}

int main()
{
    int T; scanf("%d", &T);
    while (T--) solve();
    return 0;
}

HDU6414 带劲的多项式

考虑形如

f(x) = \prod_{i=1}^r (x-\lambda_i)^{l_i}

形式的多项式,它的一阶导函数为

f'(x) = \left( \sum_{i=1}^{r}l_i(x-\lambda_i) \right) \left(\prod_{i=1}^r (x-\lambda_i)^{l_i -1}\right)

则可以发现它们之间有

\gcd(f(x), f'(x)) = \prod_{i=1}^r (x-\lambda_i)^{l_i-1}

则有

\frac{f(x)}{\gcd(f(x), f'(x))} = \prod_{i=1}^r (x-\lambda_i)

由于根重数不同,每个根的因子会在不同时刻从上式消失。每次将 \gcd(f, f’) 当作新的 f 进行迭代即可。

另外,STL好慢啊233。

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
const int MOD = 998244353;
typedef vector<int> Poly;

inline lld fpow(lld a, int k)
{
    lld b = 1;
    for (; k; k>>=1, a=a*a%MOD)
        if (k&1) b=b*a%MOD;
    return b;
}

void norm(Poly &a)
{
    lld qwq = fpow(a.back(), MOD-2);
    for (auto &x : a) x = x * qwq % MOD;
}

Poly diff(const Poly &a)
{
    Poly _a = Poly(a.size() - 1);
    for (int i = 0; i < _a.size(); i++)
        _a[i] = a[i+1] * (i+1LL) % MOD;
    return _a;
}

void divide(const Poly &a, const Poly &p, Poly &q, Poly &r)
{
    assert(p.size() >= 1 && p.back() != 0);
    q.clear(), r = a; int P = p.size();
    lld pq = fpow(p.back(), MOD-2);

    for (int i = a.size()-1; i >= P-1; i--)
    {
        q.push_back(r[i] * pq % MOD);
        for (int j = 0; j < P; j++)
            r[i-P+1+j] = (r[i-P+1+j] + MOD - 1LL * r[i] * pq % MOD * p[j] % MOD) % MOD;
    }

    reverse(q.begin(), q.end());
    while (q.back() == 0) q.pop_back();
    while (r.back() == 0) r.pop_back();
}

Poly gcd(Poly a, Poly b)
{
    Poly c, d;
    while (b.size()) divide(a, b, c, d), a = b, b = d;
    norm(a);
    return a;
}

void solve()
{
    int n; scanf("%d", &n); vector<pair<int,int>> ans;
    Poly orig(n+1, 0), dif, cbs, tmp, nil, nul;
    for (int i = 0; i <= n; i++) scanf("%d", &orig[i]);
    norm(orig); dif = diff(orig), tmp = gcd(orig, dif);
    divide(orig, tmp, cbs, nil); orig = tmp, norm(orig);

    for (int l = 1; dif.size(); l++)
    {
        dif = diff(orig), tmp = gcd(orig, dif);
        divide(orig, tmp, nul, nil);

        if (nul != cbs)
        {
            // an root appears!
            Poly a, b;
            divide(cbs, nul, b, a);
            assert(b.size() == 2);
            ans.emplace_back((MOD-b[0])%MOD, l);
            cbs = nul;
        }

        orig = tmp, norm(orig);
    }

    printf("%d\n", int(ans.size()));
    for (auto p : ans) printf("%d %d\n", p.first, p.second);
}

int main()
{
    int T; scanf("%d", &T);
    while (T--) solve();
    return 0;
}

2019 ICPC南昌网络赛部分题解

A. Enju with Math Problem

给定欧拉函数表中连续100项,求是否存在或从哪个点开始。

考虑和去年EC筛 \mu^2(n) 一样的思路,检查 5, 7, 9, 11, 13, 17, 23 的倍数所在位置,利用中国剩余定理求解可能位置,然后利用区间筛筛出该区间的欧拉函数,对比。

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
const int MAXN = 2e4+5;
int pri[MAXN], pcn; bool isnp[MAXN];
int test[100], qaq[100], phi[100];

void init_prime()
{
    for (int i = 2; i < MAXN; i++)
    {
        if (!isnp[i]) pri[pcn++] = i;
        for (int j = 0; j < pcn; j++)
        {
            if (i * pri[j] >= MAXN) break;
            isnp[i * pri[j]] = true;
            if (i % pri[j] == 0) break;
        }
    }
}

int check(int y)
{
    phi[0] = qaq[0] = y;
    for (int i = 1; i < 100; i++)
        phi[i] = qaq[i] = qaq[i-1] + 1;

    int stt = qaq[0];
    if (stt <= 0 || qaq[99] > 150000000) return -1;

    for (int i = 0; i < pcn; i++)
    {
        int ft = ((stt - 1) / pri[i] + 1) * pri[i] - stt;
        for (int j = ft; j < 100; j += pri[i])
        {
            phi[j] = phi[j] / pri[i] * (pri[i] - 1);
            while (qaq[j] % pri[i] == 0) qaq[j] /= pri[i];
        }
    }

    for (int i = 0; i < 100; i++)
        if (qaq[i] > 1)
            phi[i] = phi[i] / qaq[i] * (qaq[i] - 1);
    for (int i = 0; i < 100; i++)
        if (test[i] != phi[i])
            return -1;
    return stt;
}

lld extgcd(lld a, lld b, lld &x, lld &y)
{
    if (!b) return x = 1, y = 0, a;
    lld d = extgcd(b, a % b, y, x);
    return y -= (a / b) * x, d;
}

pair<lld,lld> modeqs(lld b[], const int w[], int k)
{
    lld bj = b[0], wj = w[0], x, y, d;
    for (int i = 1; i < k; i++)
    {
        b[i] %= w[i]; d = extgcd(wj, w[i], x, y);
        if ((bj - b[i]) % d != 0) return make_pair(0LL, -1LL);
        x = x * (b[i] - bj) / d % (w[i] / d), y = wj / d * w[i];
        bj = ((bj + x * wj) % y + y) % y, wj = y;
    }
    return make_pair(bj, wj);
}

int bs[8][25]; lld b[8];
const int m[8] = { 6, 4, 6, 10, 12, 16, 18, 22 };
const int W[8] = { 9, 5, 7, 11, 13, 17, 19, 23 };
vector<int> poss[8];

int dfs(int i)
{
    if (i == 8)
    {
        auto getAns = modeqs(b, W, 8);
        if (getAns.first == 0) return -1;
        return check(getAns.first);
    }
    else
    {
        for (auto v : poss[i])
        {
            b[i] = W[i] - v;
            int t = dfs(i+1);
            if (t > 0) return t;
        }

        return -1;
    }
}

int solve()
{
    memset(bs, 0, sizeof(bs));
    for (int i = 0; i < 8; i++)
        poss[i].clear();

    for (int i = 0; i < 100; i++)
    {
        scanf("%d", &test[i]);
        for (int j = 0; j < 8; j++)
            if (test[i] % m[j] == 0)
                bs[j][i % W[j]]++;
    }

    for (int i = 0; i < 8; i++)
        for (int j = 0; j < W[i]; j++)
            if (bs[i][j] >= 100 / W[i])
                poss[i].push_back(j);
    return dfs(0);
}

int main()
{
    init_prime();
    int T; scanf("%d", &T);

    while (T--)
    {
        int ans = solve();
        if (ans == -1) printf("NO\n");
        else printf("YES\n%d\n", ans);
    }

    return 0;
}

D. Interesting Series

考虑 F_n = 1 + a + a^2 + \cdots + a^{n-1} = \frac{a^n – 1}{a-1}

那么所求相当于是

Answer(K) = \frac{\sum_{s \in K} a^s – C_{N}^{|K|}}{a-1}

考虑构造多项式

f(n) = \prod_{i=1}^n (1 + a^s x)

那么所求 \sum_{s \in K} a^s 就是 f(n)x^{|K|} 项系数之和。

分治FFT计算即可。注意到模数比较小,所以考虑用double或者long double跑FFT,但是要注意预处理精度(感谢Claris姐姐呜呜呜),不能多次乘。

#include <bits/stdc++.h>
typedef long long lld;
const double PI = acos(-1.0);
const int MOD = 100003;
typedef std::complex<double> cplx;
int fac[MOD], inv[MOD], invs[MOD];

void dft(cplx a[], int DFT, int N) {
    static int rev[1<<20]; static cplx w[1<<20], ww[1<<20];
    for (int i = 0; i < N; i++) {
        rev[i] = (rev[i>>1]>>1)|((i&1)?N>>1:0);
        if (i < rev[i]) swap(a[i], a[rev[i]]);
        w[i] = cplx(cos(PI*i/N), sin(PI*i/N));
        ww[i] = cplx(w[i].real(), -w[i].imag());
    }
    for (int d = 0, ctz = __builtin_ctz(N); (1<<d) < N; d++) {
        cplx u, t; int m2 = 1<<d, m = m2<<1;
        for (int k = 0; k < N; k += m)
            for (int j = 0; j < m2; j++)
                t = (DFT==1?w:ww)[j<<(ctz-d)] * a[k+j+m2], u = a[k+j],
                a[k+j] = u + t, a[k+j+m2] = u - t;
    }
    if (DFT == -1) for (int i = 0; i < N; i++) a[i] /= N;
}

lld fpow(lld a, lld k) {
    lld b = 1;
    for (; k; k>>=1, a=a*a%MOD)
        if (k&1) b=b*a%MOD;
    return b;
}

int t[131072], ss[MOD], a;

void solve(int l, int r, int L, int R)
{
    if (R - L <= 31)
    {
        int len = R-L+1, sz = 1;
        for (int i = 0; i < len; i++) t[L+i] = !i;
        for (int w = l; w <= r; w++, sz++)
            for (int i = sz; i > 0; i--)
                (t[L+i] += t[L+i-1] * fpow(a, ss[w]) % MOD) %= MOD;
    }
    else
    {
        int mid = (L+R)>>1, m = (l+r)>>1, len=R-L+1;
        solve(l, m, L, mid), solve(m+1, r, mid+1, R);
        static cplx a[131072], b[131072];
        for (int i = 0; i < len/2; i++)
            a[i] = t[L+i], b[i] = t[mid+1+i],
            a[len/2+i] = b[len/2+i] = 0;
        dft(a, 1, len), dft(b, 1, len);
        for (int i = 0; i < len; i++) a[i] *= b[i];
        dft(a, -1, len);
        for (int i = 0; i < len; i++)
            t[L+i] = lld(a[i].real()+0.5) % MOD;
    }
}

int main()
{
    inv[0] = fac[0] = invs[0] = invs[1] = inv[1] = fac[1] = 1;
    for (int i = 2; i < MOD; i++)
        fac[i] = 1ll * fac[i-1] * i % MOD,
        inv[i] = 1ll * (MOD-MOD/i) * inv[MOD%i] % MOD,
        invs[i] = 1ll * invs[i-1] * inv[i] % MOD;

    int N, Q, s;
    scanf("%d %d %d", &N, &a, &Q);
    for (int i = 1; i <= N; i++)
        scanf("%d", &ss[i]);
    int sz = 30, st = 32;
    while (sz < N) sz <<= 1, st <<= 1;
    solve(1, N, 0, st-1);

    while (Q--)
    {
        scanf("%d", &s);
        lld ans = t[s] - 1ll * fac[N] * invs[s] * invs[N-s] % MOD;
        ans = (ans + MOD) * inv[a-1] % MOD;
        printf("%lld\n", ans);
    }

    return 0;
}

话说回来补题的时候一直没过竟然是因为……输出答案时那个组合数写错了呜呜呜,变量意义改变不改名,调试两行泪。

H. The Nth Item

考虑作出 f(n+x) = a \times f(n) + b \times f(n+1) 的表达式,然后分段打表。

#include <cstdio>
using namespace std;
typedef long long lld;
const int MAXN = 5e5+5;
const int MOD = 998244353;
int trans[1002][2], f[MAXN][2];

int main()
{
    trans[0][0] = trans[1][1] = f[0][1] = 1;

    for (int i = 2; i <= 1001; i++)
    {
        trans[i][0] = 2ll * trans[i-1][1] % MOD;
        trans[i][1] = (trans[i-1][0] + 3ll * trans[i-1][1]) % MOD;
    }

    for (int i = 1; i < MAXN; i++)
    {
        f[i][0] = (1ll * trans[1000][0] * f[i-1][0] + 1ll * trans[1000][1] * f[i-1][1]) % MOD;
        f[i][1] = (1ll * trans[1001][0] * f[i-1][0] + 1ll * trans[1001][1] * f[i-1][1]) % MOD;
    }

    int Q; lld N; scanf("%d %lld", &Q, &N);
    lld lastAns = N, tot = 0; N = 0;

    for (int i = 0; i < Q; i++)
    {
        lld now = lastAns ^ N; N = now; now %= (MOD-1)/2;
        lld ans = (1ll * trans[now%1000][0] * f[now/1000][0] + 1ll * trans[now%1000][1] * f[now/1000][1]) % MOD;
        tot ^= ans; lastAns = ans * ans;
    }

    printf("%lld\n", tot);
    return 0;
}

2019 ICPC徐州网络赛部分题解

E. XKC’s basketball team

就在区间 [i+1,R] 找最右边大于 a_i+m 的数的位置。然后考虑到线段树去查询,先查右子树,再查左子树,然后写着写着就成了二叉树了……

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 5e5+5;
int mmax[MAXN*4], a[MAXN];
#define lson i<<1
#define rson i<<1|1
int n, m;

int query(int x, int y, int L, int R, int i)
{
    if (mmax[i] < x || R <= y) return -1; else if (L == R) return L;
    int mid = (L+R)>>1; int q = query(x, y, mid+1, R, rson);
    if (~q) return q; else return query(x, y, L, mid, lson);
}

void build(int L, int R, int i)
{
    if (L == R) { scanf("%d", &mmax[i]); a[L] = mmax[i]; return; }
    int mid = (L+R)>>1; build(L,mid,lson), build(mid+1,R,rson);
    mmax[i] = max(mmax[lson], mmax[rson]);
}

int main()
{
    scanf("%d %d", &n, &m);
    build(1, n, 1);
    for (int i = 1; i <= n; i++)
    {
        int loc = query(a[i]+m, i, 1, n, 1);
        if (loc > 0) loc = loc - i - 1;
        printf("%d%c", loc, " \n"[i==n]);
    }
    return 0;
}

H. function

想不到,也只能抄抄题解过日子了。

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
const int MAXN = 2e5+5;
const int MOD = 998244353;
const int INV2 = (MOD+1)/2;
int pri[MAXN], pcn, sp[MAXN];
int m, Sqrt, id1[MAXN], id2[MAXN];
bool isnp[MAXN]; lld g[MAXN], h[MAXN], w[MAXN], n;

int main()
{
    scanf("%lld", &n);

    for (int i = 2; i < MAXN; i++)
    {
        if (!isnp[i])
        {
            pcn++; pri[pcn] = i;
            sp[pcn] = (sp[pcn-1] + i) % MOD;
        }

        for (int j = 1; j <= pcn; j++)
        {
            if (i * pri[j] >= MAXN) break;
            isnp[i * pri[j]] = true;
            if (i % pri[j] == 0) break;
        }
    }

    m = 0, Sqrt = sqrt(n);

    for (lld L = 1, R; L <= n; L = R+1)
    {
        R = n / (n / L), w[++m] = n / L;
        (w[m] <= Sqrt ? id1[w[m]] : id2[R]) = m;
        g[m] = (w[m] - 1) % MOD;
        h[m] = (w[m] % MOD + 2) * (w[m] % MOD - 1) % MOD * INV2 % MOD;
    }

    for (int j = 1; j <= pcn; j++)
    {
        for (int i = 1, d; i <= m && pri[j] <= w[i]/pri[j]; i++)
        {
            d = (w[i]/pri[j]<=Sqrt) ? id1[w[i]/pri[j]] : id2[n/(w[i]/pri[j])];
            g[i] -= g[d] - j + 1;
            h[i] -= pri[j] * (h[d] - sp[j-1]) % MOD;
        }
    }

    lld ans = 0;

    for (int i = 1; i <= pcn; i++)
    {
        if (n / pri[i] >= pri[i])
        {
            for (lld p = pri[i]; p <= n; p *= pri[i])
            {
                lld fk = n/p%MOD;
                ans += (n+1)%MOD*fk%MOD;
                ans -= fk*(fk+1)%MOD*INV2%MOD*p%MOD;
            }
        }
    }

    for (int i = 1; i < Sqrt; i++)
    {
        lld sp2 = h[i] - h[i+1], sp1 = g[i] - g[i+1];
        ans += (n+1)%MOD*i%MOD*sp1%MOD;
        ans -= i*(i+1ll)/2%MOD*sp2%MOD;
    }

    ans %= MOD; if (ans < 0) ans += MOD;
    printf("%lld\n", ans);
    return 0;
}

J. Random Access Iterator

假设深度为树高的叶节点集合为 R,则相当于从起点按某种移动方式转移到 R 中的概率。考虑 dp[u] 为从 u 转移不到 R 中的概率,那么有

dp[u] = \frac{1}{d(u)} \sum_{(u,v)\in T} dp[v]

同时注意深度为树高的叶节的概率为 0,其他叶节点概率为 1,即DP初值。然后两遍DFS解决。

#include <vector>
#include <cstdio>
using namespace std;
typedef long long lld;
const int MAXN = 1e6+5;
const int MOD = 1e9+7;
vector<int> G[MAXN];
int maxdep;

lld fpow(lld a, lld k)
{
    lld b = 1;
    for (; k; k>>=1, a=a*a%MOD)
        if (k&1) b=b*a%MOD;
    return b;
}

void dfs1(int u, int p, int d)
{
    maxdep = max(maxdep, d);
    for (auto v : G[u]) if (v != p) dfs1(v, u, d+1);
}

lld dfs2(int u, int p, int d)
{
    int k = 0; lld dp = 0;
    for (auto v : G[u]) if (v != p) dp += dfs2(v, u, d+1), k++;
    if (k) return fpow(dp * fpow(k, MOD-2) % MOD, k);
    else return d != maxdep ? 1 : 0;
}

int main()
{
    int n; scanf("%d", &n);

    for (int i = 1, u, v; i < n; i++)
    {
        scanf("%d %d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }

    dfs1(1,0,1);
    printf("%lld\n", (1+MOD-dfs2(1,0,1))%MOD);
    return 0;
}

K. Center

#include <unordered_map>
#include <unordered_set>
using namespace std;
typedef long long lld;
typedef pair<int,int> pii;

const lld INF = 1e8;
const lld STD = 1e7;
#define X first
#define Y second

lld trans(int x, int y) { return x * INF + y; }
lld trans(pii n) { return n.X * INF + n.Y; }
pii inv(lld t) { return pii(t/INF, t%INF); }
pii operator*(pii a, int n) { return pii(a.X * n, a.Y * n); }
pii operator+(pii a, pii b) { return pii(a.X + b.X, a.Y + b.Y); }
pii operator-(pii a, pii b) { return pii(a.X - b.X, a.Y - b.Y); }

pii p[1010];
unordered_set<lld> existed;
unordered_map<lld, int> newCenter;

int main()
{
    int n, x, y, tot;
    scanf("%d", &n), tot = n;

    for (int i = 0; i < n; i++)
    {
        scanf("%d %d", &x, &y);
        x += STD, y += STD;
        p[i] = make_pair(x, y);
        existed.insert(trans(x, y));
    }

    for (int i = 0; i < n; i++)
    {
        int ans = 0;
        for (int j = 0; j < n; j++)
            if (!existed.count(trans(p[i] * 2 - p[j])))
                ans++;
        tot = min(tot, ans);
    }

    for (int i = 0; i < n; i++)
        for (int j = i+1; j < n; j++)
            newCenter[trans(p[i] + p[j])]++;

    int ptp = 0; pii nc;
    for (auto cnt : newCenter)
        if (cnt.second > ptp)
            nc = inv(cnt.first), ptp = cnt.second;
    int maybe = 0;
    for (int j = 0; j < n; j++)
        if (!existed.count(trans(nc - p[j])))
            maybe++;
    tot = min(tot, maybe);

    printf("%d\n", tot);
    return 0;
}