服务器程序的Xamarin-Java.Interop体验(二)

原本以为会比较容易跑起来demo,但其实还是我太单纯了。

那么今天来介绍一下单纯的在C#中调用Java代码段的一些解读。这样,意味着我们在本文中会直接调用Java的类,而不会在C#中进行继承、重写等。

此时需要考虑用到两个工具:class-parse和generator。

class-parse通过读取jar包字节码,推导出每个类的public、protected方法、字段,并以XML的格式输出。此工具基本上没有太大问题,可以直接使用;当然了,你不会在C#里用java的Stream API吧,所以可以考虑改一下源码来手动去掉stream api。

generator通过读取上述工具生成的XML和部分引用程序集来生成对应的.cs文件。这个工具似乎官方的进度还不够快,有很多老旧的类名称、方法都没有修改(例如JNIEnv、RegisterAttribute、JniHandleOwnership等)需要魔改后才能正式用起来。https://github.com/yang-er/java-interop 这里提供了我自己魔改的结果,不保证运行正确性、与最终发布时的设计的一致性啊~

上述程序运行完了以后,你会获得一个一串.cs文件,然后编译之后就可以在你的C#程序里运行了。注意由于截止目前还没有支持coreclr,请使用TargetFramework = net472编译,并在linux/macos上用mono运行。另外直接根据rt.jar编译出来的文件需要进行一些修改(例如让Java.Lang.Object继承于Java.Interop.JavaObject,让Java.Lang.Throwable继承于Java.Interop.JavaException)

互操作基本方法

generator将对应类的字段、函数,生成对应的JNI调用代码,C#运行时调用这个函数就会通过JNI访问Java的对应功能。

  • 每个函数都会翻译出来四个部分:
    • 一个cb_XXXX的Delegate,用于缓存互操作的时候Java的callback,在继承和重写中需要使用。
    • 一个GetXXXXHandler,用于获取或创建上述callback的委托。
    • 一个n_XXXXX_函数,是提供上述回调类似于C++的方式访问(函数签名都是IntPtr、int等基础值类型),在C#中获取对应对象并进行调用。
    • 一个对应的函数,会将传参列表转换成jvalue*数组,然后通过JniPeerMember缓存的方法信息进行调用。
  • 普通的字段会被生成成为具有getter和setter的属性

  • 具有getXXX(),setXXX(value)的一对函数也会被翻译成属性

  • Listener、Observer之类的东西则会被翻译成事件、EventArgs等

  • 抽象类、接口会生成对应的Invoker,如果C#中没有注册返回对象实际对应类型,则会使用这些Invoker来提供一个假的C#实现,否则哪来的类来调用Java方法呢(雾)

一些细节和讨论

设计是否正确?

是否有必要将get和set对翻译成属性?我个人的观点是:只翻译成对应的函数,然后提供一个属性来访问对应函数。显然这些get和set也可能被virtual override,而重写属性的话代码就会长得比较丑了。

另外对有些类型的返回处理是否有必要?例如java.lang.String和System.String之间是否有必要每次调用都转换?数组直接返回JavaArray不也挺好?有必要将java.util.Collection,java.util.Set等翻译成System.Collections.ICollection吗?虽然生成的代码更C#了,但是实际上似乎会比较影响GC和性能吧?我个人持怀疑态度。

IJavaPeerable

目前与Xamarin.Android一个很大的变化是,他们决定废弃JNIEnv这个不伦不类的类,改为使用JniEnvironment这个进行良好的整理的类。所以类的生成内容都有变化。原来的JniEnv中提供了直接对IntPtr操作的类,现在由JniObjectReference提供对应的方法来复制,整理的更加“干净”。

在Xamarin团队决定将互操作支持带到桌面上的时候,他们一开始使用了SafeHandle来代替原来的IntPtr,但是发现性能下降明显,所以后期他们全部改成了JniObjectReference。目前的generator大部分还都返回IntPtr+JniHandleOwnership,你需要改成ref JniObjectReference+JniObjectReferenceOptions。

除此之外,与初代实现的不同一点是,

类型系统相容性

显然Java中,Throwable是继承于Object的,但是如果想在C#中强类型处理Java异常,Throwable就不能再继承于Object了,除非之后CLR规范修改(雾)

另外目前的Generator生成出来的并没有泛型,全部都是平铺直叙的类。如果想支持C#那样的泛型,需要后期他们继续增加支持,目前你需要自己写一些胶水代码(继承、重写、cast)来“支持”。

另外Java还支持重写某函数以后返回比父类更具体的子类类型,这一点C#是不支持的,所以你可能需要修改生成的胶水代码才能编译。

性能

这套框架走JNI,所以其实性能其实不会太差?但是需要注意的是,这套框架目前翻译Java数组、CharSequence的时候,会有Java数组内容复制到C#数组,和C#数组内容复制到Java数组里,这两个过程,你需要非常小心,尽量在胶水中少使用数组,多使用ArrayList等。

完成进度

我怎么总觉得按他们的速度,这个功能会跳票啊?(大雾)

服务器程序的Xamarin-Java.Interop体验(一)

这几天需要写一个用到Java模块的程序,但是Java是不可能写的,这辈子都不可能写的,只能搞搞interop了。

目前市面上已有的基本上是IKVM.NET和JNBridgePro,后者没太了解技术细节,前者看起来是只有单向的互操作(JVM是跑在CLR上的,或者将Java字节码翻译到MSIL)。

想起来之前好像说.NET 5.0要支持Java互操作,但是翻了翻dotnet/runtime库,丝毫看不出来仓库内在搞支持。后来就想了想,换了xamarin/java.interop库研究看看。

按照之前Xamarin.Android的做法的话,互操作应该是双向的。C#这边可以继承Java的类,然后Java那边也会生成访问对应C#代码的代码。

然后发现……他们正在支持.NET Core 3.1,但是其JNI库引用的头文件还是mono的,而且用到了pthread和dlfcn的头文件,也就是说……现在必须在Linux/macOS和mono下运行。

那么先来build一下吧~

此处以Ubuntu 18.04为例。首先需要准备一些系统依赖。编译要很久,还是选择apt安装吧。

sudo apt install gnupg ca-certificates
sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv-keys 3FA7E0328081BFF6A14DA29AA6A19B38D3D831EF
echo "deb https://download.mono-project.com/repo/ubuntu stable-bionic main" | sudo tee /etc/apt/sources.list.d/mono-official-stable.list
sudo apt update
sudo apt install openjdk-8-jdk mono-devel nuget dotnet-sdk-3.1
sudo ln -s /usr/include/mono-2.0/mono /usr/include/mono

编译的时候TargetFrameworks要用到netcoreapp3.1,所以得安装上。然后就是编译内容了。

先clone一下。

git clone https://github.com/xamarin/java.interop --depth=1
cd java.interop

然后先简单修改一下几个文件。

diff --git a/Directory.Build.props b/Directory.Build.props
index 521e68a..1da7d44 100644
--- a/Directory.Build.props
+++ b/Directory.Build.props
@@ -43,6 +43,8 @@
     <XamarinAndroidToolsDirectory   Condition=" '$(XamarinAndroidToolsDirectory)' == '' ">$(MSBuildThisFileDirectory)external\xamarin-android-tools</XamarinAndroidToolsDirectory>
   </PropertyGroup>
   <PropertyGroup>
+    <JavaCPath>/usr/lib/jvm/java-8-openjdk-amd64/bin/javac</JavaCPath>
+    <JarPath>/usr/lib/jvm/java-8-openjdk-amd64/bin/jar</JarPath>
     <JavacSourceVersion Condition=" '$(JavacSourceVersion)' == '' ">1.8</JavacSourceVersion>
     <JavacTargetVersion Condition=" '$(JavacTargetVersion)' == '' ">1.8</JavacTargetVersion>
     <_BootClassPath Condition=" '$(JreRtJarPath)' != '' ">-bootclasspath "$(JreRtJarPath)"</_BootClassPath>
diff --git a/samples/Hello/Program.cs b/samples/Hello/Program.cs
index 6ffacbb..9f45fac 100644
--- a/samples/Hello/Program.cs
+++ b/samples/Hello/Program.cs
@@ -10,6 +10,7 @@ namespace Hello
                public static unsafe void Main (string[] args)
                {
                        Console.WriteLine ("Hello World!");
+                       JreRuntime.Initialize("/usr/lib/jvm/java-8-openjdk-amd64/jre/lib/amd64/server/libjvm.so");
                        try {
                                var ignore = JniRuntime.CurrentRuntime;
                        } catch (InvalidOperationException e) {
diff --git a/src/Java.Interop/Java.Interop/JniRuntime.cs b/src/Java.Interop/Java.Interop/JniRuntime.cs
index 6de9021..f9fa0de 100644
--- a/src/Java.Interop/Java.Interop/JniRuntime.cs
+++ b/src/Java.Interop/Java.Interop/JniRuntime.cs
@@ -149,7 +149,8 @@ namespace Java.Interop
                                Debug.Assert (count == 0);
                                var available   = GetAvailableInvocationPointers ().FirstOrDefault ();
                                if (available == IntPtr.Zero)
-                                       throw new NotSupportedException ("No available Java runtime to attach to. Please create one.");
+                                       return null;
+                                       //throw new NotSupportedException ("No available Java runtime to attach to. Please create one.");
                                var options     = new CreationOptions () {
                                        DestroyRuntimeOnDispose = false,
                                        InvocationPointer       = available,
diff --git a/src/Java.Runtime.Environment/Java.Interop/JreRuntime.cs b/src/Java.Runtime.Environment/Java.Interop/JreRuntime.cs
index ea1489f..9ca06b0 100644
--- a/src/Java.Runtime.Environment/Java.Interop/JreRuntime.cs
+++ b/src/Java.Runtime.Environment/Java.Interop/JreRuntime.cs
@@ -72,6 +72,14 @@ namespace Java.Interop {

        public class JreRuntime : JniRuntime
        {
+               public static void Initialize(string path)
+               {
+                       int r = NativeMethods.java_interop_jvm_load (path);
+                       if (r != 0) {
+                               throw new Exception ($"Could not load JVM path `{path}` ({r})!");
+                       }
+               }
+
                static int CreateJavaVM (out IntPtr javavm, out IntPtr jnienv, ref JavaVMInitArgs args)
                {
                        return NativeMethods.java_interop_jvm_create (out javavm, out jnienv, ref args);

另外,OpenJDK11应该也是可用的,不过得注意JavacSourceVersion和JavacTargetVersion=11,由于使用的部分代码还是java8标准所以建议继续JavacSourceVersion=1.8。暂时还没实验jdk11。

文件差不多编辑完了,来编译。

make src/Java.Runtime.Environment/Java.Runtime.Environment.dll.config
make
mono bin/TestDebug/Hello.exe

此时会显示运行成功的样子。如果没成功,那就是我忘了哪个步骤没写(逃)

Hello World!
Part 2!
# JniEnvironment.EnvironmentPointer=94212541059552
vm.SafeHandle=140206052962432
java.lang.Object=0x55af91090e50/L
hashcode=1735600054
WITHIN: GetCreatedJavaVMs: 140206052962432
POST: GetCreatedJavaVMs: 140206052962432

接下来的文章将大致介绍如何在C#中直接调用Java代码,而不是JniType一顿操作。

EFCore查询语句生成流程、让EFCore支持批量Update/Delete/MergeInto

引子

之前发现了一款叫 EFCore.BulkExtensions 的 nuget 包。里面提供了大量的 BulkInsertOrUpdateOrDelete 和 BatchUpdate 的拓展,可以很方便的解决批量更新和删除的问题,不用让 EFCore 一条一条的删除和更新。

其中几个比较有用的函数签名是

Task<int> BatchDeleteAsync(this IQueryable<T> queryable);
Task<int> BatchUpdateAsync(this IQueryable<T> queryable, Expression<Func<T, T>> updateExpression);

但是在升级到 ASP.NET Core 3.1 的时候,所有 Where 中的 someArray.Contains(i.Key) 全部挂掉了。而我的程序里用这一语句比较多,遂下载了其源代码并合并了当时作者几个月都没合并的一个PR。

研究代码,总结了该程序的基本运行过程:

  1. 通过反射获取各种私有变量来访问到 DbContext
  2. updateExpression 由这个包自己访问表达式树获得
  3. 让 IQueryable 执行 GetEnumerator 让 EFCore 生成对应的 Select 语句,进行字符串拼接
  4. 由 DbContext.Database.ExecuteSqlRaw 来完成语句执行

但是这过程有几个问题:

  1. 有几种句式 updateExpression 会翻译不了
  2. 由其原来实现的 updateExpression 翻译后的某些参数的 SQL 类型不对
  3. 我需要一个 INSERT INTO SELECT FROM 的句式,它不支持
  4. 我需要一个 upsert 功能,但是原来的 BulkInsertOrUpdate 不能在原表基础上操作

遂研究 IQueryable.Provider.Execute<T> 是什么执行流程。

语句生成过程

我觉得在翻代码的过程中,有这么一首歌比较符合我的心情:如果你愿意一层一层一层一层的拨开我的心,你会发现,你会讶异,你是我最压抑最深处的秘密。

  1. 调用 QueryCompiler.ExtractParameters,将其中的闭包捕捉变量参数化
  2. 检查是否已经缓存了这个查询表达式,如果没有则转入 QueryCompilationContext 处理,否则转到8
  3. QueryTranslationPreprocessor 处理,在原来的表达式树上先跳舞
  4. QueryableMethodTranslatingExpressionVisitor 将原来的表达式树翻译成一个 ShapedQueryExpression,而这一个表达式则包含了几个部分:SelectExpressionShaperExpressionResultCardinality。其中前者是可以翻译成 SQL 语句的表达式,中间的是将查询出来的元组映射到实体类型,最后一个是查询的维度(Enumerable、Single、SingleOrDefault)
  5. QueryTranslationPostprocessor 处理,其中比较重要的是将查询的字段加入 SELECT 的 Projection 列表
  6. ShapedQueryCompilingExpressionVisitorShapedQueryExpression 缓存,并转换成为 IRelationcalCommandCache,然后构造一个 QueryableEnumerable 的 NewExpression。前者包含了该查询语句需要的参数、查询语法树、查询字符串,后者是进行语句执行的类
  7. 将上述 NewExpression 和将 QueryCompilationContext 中的查询参数加到 QueryContext 中的语句合并成为一个代码块,然后 Lambda Compile
  8. 生成 DbCommandIRelationcalCommandCache 获取字符串并加入各种参数进行查询

翻译结束了,查询到这里也就可以开始了。

支持批量操作?

IRelationalCommandCache 是怎么生成字符串的呢?没错,就是 QuerySqlGenerator 啦。

那么,也就是说,我们能过拿到 Select Expression 的话,一切都好说。

上述过程中,最后的 IRelationalCommandCache 中会包含这个 SelectExpression。我们可以魔改这个啊!

DELETE 语句的生成比较简单。我们构建一个 DeleteExpression 类,将要删除的 Table、删除中的 Predicate、删除个数限制 Limit、原来的一些 Join 全部获取出来,就好了。然后在我们自己继承的 SqlServerQuerySqlGenerator 中实现这个部分。

INSERT INTO SELECT 也比较简单,只要构建一个 InsertIntoSelectExpression 类,将要插入的表 Table 和 SelectExpression 保存起来,就好了。

UPDATE SET 可能比较麻烦。但是我们可以骚操作啊!将那个 updateExpression 变成 Select 的字段,然后再读取 SelectExpression 中的 ProjectionExpression 不就好了吗~我真是个小天才。

MERGE INTO 是最烦的,因为结构过于复杂,涉及到 Target、Source、JoinPredicate、Limit、Matched、NotMatchedByTarget、NotMatchedBySource。过程中还要实现一些表的更名之类的。目前我只是实现了这些,但是想做出 Matched When 功能以后再发布到 nuget 上,这个实现实在是过于复杂,不知道有没有人帮帮我啊 TAT。

由于翻译 SqlExpression 最方便还是基于 QuerySqlGenerator 操作,所以就写一个 EnhancedQuerySqlGenerator 类来满足我们的需求,并在 DbContextOptionsBuilder 那边将这个 Factory 替换掉。

实现了这些,GitHub 地址:Microsoft.EntityFrameworkCore.Bulk,可以在 github packages 上下载目前版本的 nuget 包。

另外 src/Internal/TranslationGoThrough.cs 中有上述语句生成过程的一个缩影,和系统版本几乎一致,唯一不同的是修改了 ExtractParameters 函数。

因为原来的 Extract 过程有一个事情很诡异:在生成参数的时候,我们可以进行一些本地执行,但是如果不阻止某些本地执行程的话,可能会导致 UPDATE 语句的字段全部空。例如 updateExpression 中没有利用到原表的参数并且不捕捉闭包变量的时候,那么不会被本地执行,但是如果没有利用到原表的参数还捕捉闭包变量的时候,它就会被直接本地执行,字段空啦~(确实不懂他们这段代码逻辑怎么写的,你生成查询的时候优化这个的话,怎么不把前面一个也优化掉啊……

ASP.NET Core 修改 EndpointRouting 的链接生成行为

微软在 ASP.NET Core 2.2 时期引入了 EndpointRouting,并且在之后的 3.0 / 3.1 进行了很多的升级改造。我前段时间刚刚把网站升级到了 3.1,之前一直在使用 2.1 的兼容模式,并且内部有很多的属性路由的配置,而没有使用 DefaultRoute 那条规则。按照官方的文档将 UseMvc 那套替换成了 EndpointRouting 那一套。然后发现,很多的链接生成出了问题。

由于之前略微读过 AnchorTagHelper 的源码,大约知道问题出在了 UrlHelper 身上,将。随后发现, 在使用终结点路由后,UrlHelper 的实现类型变成了 EndpointRoutingUrlHelper。而后者则在内部调用了 LinkGenerator,其默认实现为 DefaultLinkGenerator

LinkGenerator 的实现使用到了几个比较重要的对象:EndpointDataSource 是所有终结点的集合,TemplateBinderFactory 是根据终结点的 RoutePattern 生成 TemplateBinder 的工厂,TemplateBinder 是一个保存了 RoutePattern 内部信息(例如默认RouteValues、链接模板)并提供实际链接生成的工具。由于微软在这方面的代码编写中全都 internal 了,且代码中注释较少,在接下来我将介绍他们的基本工作流程,并给出一个能够基本做到“兼容”、不修改太多代码的方法。此处的解释仅仅针对属性路由。

寻找终结点

请关注 GetEndpoints<TAddress>(TAddress address) 函数。此处,TAddress 取值有两种:string、RouteValuesAddress。

当 TAddress 为 string 时,请回顾 RouteAttribute 中的 Name 属性。没错就是这个,在 anchor 链接使用 asp-route=”something” 的时候,路由的查找是根据此处进行的。在生成终结点时,程序已经确保了所有的 RouteName 都唯一。查找直接找字典就好了。

当 TAddress 为 RouteValuesAddress 时,是使用 asp-area、asp-controller、asp-page、asp-action、asp-route-xxx 时进行的寻址方式。此 Address 由三个部分组成,其中包括 RouteName、ExplicitValues、AmbientValues。根据 RoutePattern 中的 RequiredValues 来检查 ExplicitValues 和 AmbientValues 的值。

结束后返回一个终结点列表,其中包含可能匹配的结果。我们针对每个可能的终结点,尝试匹配路由模板。

尝试匹配路由模板

生成一个 TemplateBinder,这个对象提供三个函数:GetValues、TryProcessConstraints、TryBindValues。

GetValues 会根据 RoutePattern 和两种 RouteValue 生成一个最终匹配列表,使用了 Ambient Values Invalidation Algorithm,在接下来介绍。

TryProcessConstraints 是根据上方匹配,检查是否满足路由值的限制。

TryBindValues 是将获取的 RouteValues 生成最后的终结点链接。

隐式路由值失效算法

首先介绍几个字段,在接下来会被算法使用到。

  • _slots 是一个 KVP<string, object>,其 Key 为路由值的名称,Value 为 null。其中前几个字段依次是 pattern 的参数,后几个字段是filter默认值。

  • _requiredKeys 是一个 string 数组,是以 RoutePattern 的字段中提取的。

  • _defaults 是默认路由值的字典。

  • _pattern 是 RoutePattern。

  • ambientValues 是目前页面的路由值字典。

  • values 是即将导航的页面的路由值字典,一般由asp-action等等计算得到。

最开始会将“需要填的空格”复制一份,以供本次计算使用。

  • 首先将所有需要填的空在 values 中寻找一遍,决定使用该值或者 explicit null 值。

  • 考虑针对每个 requiredKeys 是否继续复制隐式路由值。

    • 如果 ambientValues 为空引用,那么不复制。
    • 对于每一个在上一步中确定的显式值,检查和隐式路由值是否相同。检查过程会考虑值相等、any值匹配、显式null值。
  • 对于每一个路由参数,检查是否存在显式值、隐式值。
    • 如果还可以复制隐式值,那么检查此处的显式和隐式值是否匹配。如果不匹配,那么不再复制后面的隐式值。
    • 如果不能再复制隐式值,并且没有显式值,但它确实是路由需要的值,而且存在隐式值,并且和要求的值相等,那么使用它。
    • 如果在以上两种情况中匹配成功,那么将它加入 accepted 匹配成功的路由字典。
    • 如果是可选参数或者通配参数,那么将它从路由字典中忽略。
    • 如果没有匹配但是有默认值,那么使用默认值。
    • 以上过程均没有对应的话,认为此Endpoint匹配失败,不适用于此结果。
  • 对每个filter字段进行检查。
    • 如果存在显式值,检查两者是否相等,如果不相等那么就是匹配失败
    • 如果不存在显式值,那么不加入 accepted 匹配成功的路由字典。
  • 如果存在没有处理的显式值,加入 combined 合并后的路由字典。

  • 对于一些非参数的隐式值,需要加入到 combined 字典,这样可以对路由限制可见。

破坏的部分

以上对于 /[area]/{area_id}/[controller]/[action] 非常的不友好。

可以通过将该 TemplateBinder 的 _requiredValues 清空来达到这一点操作。但是这是 readonly 的 private 字段,所以使用反射解决。

using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Routing.Template;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

namespace Microsoft.AspNetCore.Routing
{
    public sealed class OrderLinkGenerator : LinkGenerator, IDisposable
    {
        private readonly LinkGenerator inner;
        private readonly TemplateBinderFactory _binderFactory;
        private readonly Func<RouteEndpoint, TemplateBinder> _createTemplateBinder;
        private readonly FieldInfo _requiredKeys;
        const string typeName = "Microsoft.AspNetCore.Routing.DefaultLinkGenerator";
        internal static Type typeInner;

        public OrderLinkGenerator(
            ParameterPolicyFactory parameterPolicyFactory,
            TemplateBinderFactory binderFactory,
            EndpointDataSource dataSource,
            IOptions<RouteOptions> routeOptions,
            IServiceProvider serviceProvider)
        {
            if (typeInner.FullName != typeName)
                throw new NotImplementedException();
            var logger = serviceProvider.GetService(typeof(ILogger<>).MakeGenericType(typeInner));
            var autoFlag = BindingFlags.NonPublic | BindingFlags.Instance;

            var args = new object[]
            {
                parameterPolicyFactory,
                binderFactory,
                dataSource,
                routeOptions,
                logger,
                serviceProvider
            };

            var ctorInfo = typeInner.GetConstructors()[0];
            var newExp = Expression.New(ctorInfo, args.Select(o => Expression.Constant(o)));
            var ctor = Expression.Lambda<Func<LinkGenerator>>(newExp).Compile();
            inner = ctor();

            _binderFactory = binderFactory;
            _createTemplateBinder = CreateTemplateBinder;
            var fieldInfo = typeInner.GetField(nameof(_createTemplateBinder), autoFlag);
            fieldInfo.SetValue(inner, _createTemplateBinder);

            _requiredKeys = typeof(TemplateBinder).GetField(nameof(_requiredKeys), autoFlag);
        }

        private TemplateBinder CreateTemplateBinder(RouteEndpoint endpoint)
        {
            /*
             * The following code section is disabled
             * for its change to RoutePattern may cause
             * errors.
             * 
             * var rawText = endpoint.RoutePattern.RawText;
             * var rv = endpoint.RoutePattern.RequiredValues as RouteValueDictionary;
             *
             * if (rawText != null)
             * {
             *     var m = Regex.Matches(rawText, "\\{(\\w+)\\}");
             *     for (int i = 0; i < m.Count; i++)
             *         rv.Add(m[i].Value.TrimStart('{').TrimEnd('}'), RoutePattern.RequiredValueAny);
             * }
             * 
             * A better solution is to disable the _requiredKeys.
             */

            var binder = _binderFactory.Create(endpoint.RoutePattern);
            _requiredKeys.SetValue(binder, Array.Empty<string>());
            return binder;
        }

        public override string GetPathByAddress<TAddress>(HttpContext httpContext, TAddress address, RouteValueDictionary values, RouteValueDictionary ambientValues = null, PathString? pathBase = null, FragmentString fragment = default, LinkOptions options = null) =>
            inner.GetPathByAddress(httpContext, address, values, ambientValues, pathBase, fragment, options);
        public override string GetPathByAddress<TAddress>(TAddress address, RouteValueDictionary values, PathString pathBase = default, FragmentString fragment = default, LinkOptions options = null) =>
            inner.GetPathByAddress(address, values, pathBase, fragment, options);
        public override string GetUriByAddress<TAddress>(HttpContext httpContext, TAddress address, RouteValueDictionary values, RouteValueDictionary ambientValues = null, string scheme = null, HostString? host = null, PathString? pathBase = null, FragmentString fragment = default, LinkOptions options = null) =>
            inner.GetUriByAddress(httpContext, address, values, ambientValues, scheme, host, pathBase, fragment, options);
        public override string GetUriByAddress<TAddress>(TAddress address, RouteValueDictionary values, string scheme, HostString host, PathString pathBase = default, FragmentString fragment = default, LinkOptions options = null) =>
            inner.GetUriByAddress(address, values, scheme, host, pathBase, fragment, options);
        public void Dispose() => ((IDisposable)inner).Dispose();
    }
}

最后在依赖注入容器中替换即可。

public static IMvcBuilder ReplaceLinkGenerator(this IMvcBuilder mvc)
{
    var old = mvc.Services.FirstOrDefault(s => s.ServiceType == typeof(LinkGenerator));
    OrderLinkGenerator.typeInner = old.ImplementationType;
    mvc.Services.Replace(ServiceDescriptor.Singleton<LinkGenerator, OrderLinkGenerator>());
    return mvc;
}

ASP.NET Core中使用Basic Authentication

前段时间需要给某个项目接入Web API,并且使用的是基于HTTP Header的Basic Authorzation。

最开始的写法是在appsettings.json里开辟字段存储用户名和密码,然后给对应的Controller加IActionFilter,后来随着项目的升级,想着接入ASP.NET Core自带的用户系统。

翻了翻Microsoft Docs,并没有找到一些科学的内容。后来在NuGet上找到了idunno.Authentication.Basic这样的一个NuGet包。

这是其GitHub Repo上提供的代码示例。

public void ConfigureServices(IServiceCollection services)
{
    services.AddAuthentication(BasicAuthenticationDefaults.AuthenticationScheme)
            .AddBasic(options =>
            {
                options.Realm = "idunno";
                options.Events = new BasicAuthenticationEvents
                {
                    OnValidateCredentials = context =>
                    {
                        if (context.Username == context.Password)
                        {
                            var claims = new[]
                            {
                                new Claim(
                                    ClaimTypes.NameIdentifier, 
                                    context.Username, 
                                    ClaimValueTypes.String, 
                                    context.Options.ClaimsIssuer),
                                new Claim(
                                    ClaimTypes.Name, 
                                    context.Username, 
                                    ClaimValueTypes.String, 
                                    context.Options.ClaimsIssuer)
                            };

                            context.Principal = new ClaimsPrincipal(
                                new ClaimsIdentity(claims, context.Scheme.Name));
                            context.Success();
                        }

                        return Task.CompletedTask;
                    }
                };
            });

    // All the other service configuration.
}

在此之后,给需要使用Basic的Controller添加

[Authorize(AuthenticationSchemes = "Basic")]

即可。


众所周知,ASP.NET Core使用的是基于Claim的认证与授权。

如果想使用自带的用户系统(services.AddIdentity()那套),则直接使用SignInManager就好啦。

那么大致的验证代码(OnValidateCredentials)就是

private static async Task ValidateAsyncOld(ValidateCredentialsContext context)
{
    var signInManager = context.HttpContext.RequestServices
        .GetRequiredService<SignInManager<TUser>>();
    var userManager = signInManager.UserManager;
    var user = await userManager.FindByNameAsync(context.Username);

    if (user == null)
    {
        context.Fail("User not found.");
        return;
    }

    var checkPassword = await signInManager.CheckPasswordSignInAsync(user, context.Password, false);
    if (!checkPassword.Succeeded)
    {
        context.Fail("Login failed, password not match.");
        return;
    }

    context.Principal = await signInManager.CreateUserPrincipalAsync(user);
    context.Success();
}

此后发现验证的性能问题。

显然每次访问API接口都会进行四次数据库查询!(Users一次,UserClaims一次,Roles一次,RoleClaims一次)

在轮询的接口中出现了非常严重的性能问题,毕竟主要程序代码只有一个查询,而用户信息就要查询四次。


不过不要紧,我们还有MemoryCache。

将用户信息缓存在内存中,并且定期清除,让几百次登录同一账户查询才生成一次ClaimsIdentity就会将查询用户信息的时间均摊的更低了。

同时由于我的这个项目不使用UserClaims和RoleClaims,也可以在ClaimsIdentity生成过程中忽略这两个步骤。在用户信息的查询中也可以省略掉一堆一堆一堆一堆不需要的字段。

根据SignInManager、UserManager等的源码,最后挖掘出来了四个接口需要使用

  • IPasswordHasher<TUser>
  • IOptions<IdentityOptions>
  • IdentityDbContext<...>
  • IUserClaimsPrincipalFactory<TUser> (已经展开到最后代码中)

这样就不需要使用UserManager、SignInManager而底层的直接生成ClaimsIdentity和ClaimsPrincipal。

private readonly static IMemoryCache _cache =
    new MemoryCache(new MemoryCacheOptions()
    {
        Clock = new Microsoft.Extensions.Internal.SystemClock()
    });

private static async Task ValidateAsync(ValidateCredentialsContext context)
{
    var dbContext = context.HttpContext.RequestServices
        .GetRequiredService<TContext>();
    var normusername = context.Username.ToUpper();

    var user = await _cache.GetOrCreateAsync("`" + normusername.ToLower(), async entry =>
    {
        var value = await dbContext.Users
            .Where(u => u.NormalizedUserName == normusername)
            .Select(u => new { u.Id, u.UserName, u.PasswordHash, u.SecurityStamp })
            .FirstOrDefaultAsync();
        entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(5);
        return value;
    });

    if (user == null)
    {
        context.Fail("User not found.");
        return;
    }

    var passwordHasher = context.HttpContext.RequestServices
        .GetRequiredService<IPasswordHasher<TUser>>();

    var attempt = passwordHasher.VerifyHashedPassword(
        user: default, // assert that hasher don't need TUser
        hashedPassword: user.PasswordHash,
        providedPassword: context.Password);

    if (attempt == PasswordVerificationResult.Failed)
    {
        context.Fail("Login failed, password not match.");
        return;
    }

    var principal = await _cache.GetOrCreateAsync(normusername, async entry =>
    {
        var uid = user.Id;
        var ur = await dbContext.UserRoles
            .Where(r => r.UserId.Equals(uid))
            .Join(dbContext.Roles, r => r.RoleId, r => r.Id, (_, r) => r.Name)
            .ToListAsync();

        var options = context.HttpContext.RequestServices
            .GetRequiredService<IOptions<IdentityOptions>>().Value;

        // REVIEW: Used to match Application scheme
        var id = new ClaimsIdentity("Identity.Application",
            options.ClaimsIdentity.UserNameClaimType,
            options.ClaimsIdentity.RoleClaimType);
        id.AddClaim(new Claim(options.ClaimsIdentity.UserIdClaimType, $"{user.Id}"));
        id.AddClaim(new Claim(options.ClaimsIdentity.UserNameClaimType, user.UserName));
        id.AddClaim(new Claim(options.ClaimsIdentity.SecurityStampClaimType, user.SecurityStamp));
        foreach (var roleName in ur)
            id.AddClaim(new Claim(options.ClaimsIdentity.RoleClaimType, roleName));
        var value = new ClaimsPrincipal(id);

        entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(20);
        return value;
    });

    context.Principal = principal;
    context.Success();
}

缓存的5分钟/20分钟可以根据需求来修改。

这样就有一个似乎“完美”一点的Basic Authentication啦!

2019ICPC西安 C. Dirichlet k-th root

我不会狄利克雷卷积。输得彻底。

给出一个很直观的解法代码。

前提是理解好 f(n) = xf^k(n) = kx+b_{k,n}

那么

f^{2k}(n) = (f^k * f^k)(n) = 2f^k(n) + \sum_{d|n, d\not=1,d\not=n} f^k(d) f^k(n/d)

f^{k+1}(n) = (f^k * f)(n) = f^k(n) + f(n) + \sum_{d|n, d\not=1,d\not=n} f(d) f^k(n/d)

首先显然 f^k(1) = 1,假设已经求出 f(1) \cdots f(n-1) 里的所有值,以及可能取到的上标 k,那么就可以得到 g(n) = f^K(n) = Kx+b_{K,n}。我们可以从小往大递推 b_{i,n},然后求得 b_{K,n} 后得到 f(n),并解出 f^k(n) 以方便后续计算。

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;
const int mod = 998244353;
int g[maxn], f[maxn][64], b[maxn][64];
vector<int> d[maxn];
int n, k, kk[64], invk, tok;

int fpow(long long a, int k) {
    long long b = 1;
    for (; k; k >>= 1, a = a * a % mod)
        if (k & 1) b = b * a % mod;
    return b;
}

int main()
{
    scanf("%d %d", &n, &k);
    for (int i = 1; i <= n; i++)
        scanf("%d", &g[i]);
    for (int i = 2; i <= n; i++)
        for (int j = i+i; j <= n; j += i)
            d[j].push_back(i);
    for (int tmp = k; tmp; ) {
        kk[++tok] = tmp;
        if (tmp % 2 == 1) tmp--; else tmp /= 2;
    } // the used ks
    invk = fpow(k, mod-2);

    // find for each f(n)
    for (int i = 1; i <= tok; i++) f[1][i] = 1;
    for (int i = 2; i <= n; i++)
    {
        // if we know about f[1..n-1], let f[n] = x
        // so it can be proved that fk[n] = kx+b
        for (int j = tok-1; j; j--)
        {
            long long tmp = b[i][j+1];
            if (!(kk[j] & 1)) tmp <<= 1;
            int sec = (kk[j] & 1) ? tok : j+1;
            for (auto dd : d[i])
                tmp += 1ll * f[dd][j+1] * f[i/dd][sec] % mod;
            b[i][j] = tmp % mod;
        }

        // now we have g[n] = Kx+b, just...
        f[i][tok] = 1ll * (g[i] - b[i][1] + mod) * invk % mod;
        for (int j = tok-1; j; j--)
            f[i][j] = (1ll * kk[j] * f[i][tok] + b[i][j]) % mod;
        assert(f[i][1] == g[i]);
    }

    for (int i = 1; i <= n; i++)
        printf("%d ", f[i][tok]);
    return 0;
}

复杂度为 O(n\log n \log k)

另外给出一个道听途说的做法,当时完全没有想到的。

如果你手算能力够强,那么可以发现 b 是个类似于求和的东西,结果对 k 敏感,如果令 k = 998244353,那么求和结果为 0,对 g\operatorname{inv}k 次也可以得到 f

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;
const int mod = 998244353;
int g[maxn], f[maxn], n;

int fpow(long long a, int k) {
    long long b = 1;
    for (; k; k >>= 1, a = a * a % mod)
        if (k & 1) b = b * a % mod;
    return b;
}

void conv(int a[maxn], const int b[maxn]) {
    static int c[maxn];
    fill(c, c+n+1, 0);
    for (int i = 1; i <= n; i++)
        for (int j = 1; j * i <= n; j++)
            c[j*i] = (c[j*i] + 1ll * a[i] * b[j]) % mod;
    copy(c, c+n+1, a);
}

int main()
{
    int k; scanf("%d %d", &n, &k);
    for (int i = 1; i <= n; i++)
        scanf("%d", &g[i]);
    k = fpow(k, mod-2); f[1] = 1;
    for (; k; k >>= 1, conv(g, g))
        if (k & 1) conv(f, g);
    for (int i = 1; i <= n; i++)
        printf("%d ", f[i]);
    return 0;
}

不太会证明,能AC,复杂度为 O(n\log n \log \operatorname{inv} k)

2018ICPC北京 C. Pythagorean triple

数出所有满足 a^2+b^2=c^2c \le n 的正整数三元组。

首先,如果 (x,y,z) 是一个勾股数三元组,那么 (ax,ay,az) 也是个勾股数三元组。那么,不妨考虑 f(n) 为斜边长为 n 的勾股三元组数量,g(n) 为斜边长为 n 的本原勾股数,那么有

F(n) = \sum_{i=1}^n f(i) = \sum_{i=1}^n \sum_{d|i} g(d) = \sum_{d=1}^n \sum_{i=1}^{\lfloor n/d \rfloor} g(i) = \sum_{d=1}^n G(\left\lfloor\frac{n}{d}\right\rfloor)

考虑到所有本原勾股三元组可以表示为 (i^2-j^2, 2ij, i^2+j^2),其中 \gcd(i,j)=1,并且不能同时为奇数,那么有

\begin{aligned} G(n) & = \sum_{i 为奇} \sum_{j 为偶} [i^2+j^2 \le n] [\gcd(i,j)=1] \\ & = \sum_{d=1}^{\sqrt{n/2}} \mu(d) \sum_{i 为奇} \sum_{j 为偶} [i^2+j^2 \le n] [d|i] [d|j] \\ & = \frac{1}{2} \sum_{d=1}^{\sqrt{n/2}} \mu(d) \left( \sum_{i} \sum_{j} – \sum_{i 为奇} \sum_{j 为奇}\right ) [i^2+j^2 \le n] [d|i] [d|j] \\ & = \frac{1}{2} \sum_{d=1}^{\sqrt{n/2}} \mu(d) \left( \sum_{i} \left\lfloor\frac{\sqrt{n-i^2d^2}}{d}\right\rfloor – [d为奇] \sum_{i 为奇} \left\lfloor\frac{\frac{\sqrt{n-i^2d^2}}{d}+1}{2}\right\rfloor \right ) \end{aligned}

感觉写到这里一脸复杂度会超掉的样子,那么继续化简试试。

\begin{aligned} F(n) & = \sum_{c=1}^n G(\left\lfloor\frac{n}{c}\right\rfloor) \\ & = \frac{1}{2} \sum_{c=1}^n \sum_{d=1}^{\sqrt{n/2c}} \mu(d) \left( \sum_{i} \left\lfloor\frac{\sqrt{n/c-i^2d^2}}{d}\right\rfloor – [d为奇] \sum_{i 为奇} \left\lfloor\frac{\frac{\sqrt{n/c-i^2d^2}}{d}+1}{2}\right\rfloor \right ) \\ & = \frac{1}{2} \sum_{d=1}^{\sqrt{n/2}} \mu(d) \sum_{c=1}^n \left( \sum_{i} \left\lfloor\sqrt{\frac{n}{cd^2}-i^2}\right\rfloor – [d为奇] \sum_{i 为奇} \left\lfloor\frac{1}{2}\left(\sqrt{\frac{n}{cd^2}-i^2}+1\right)\right\rfloor \right ) \end{aligned}

那么可以发现最后那个对 c 求和的东西对外只和 n/d^2 有关系,对内可以数论分块,一脸时间复杂度 O(n^{3/4}) 的模样,复杂度够了,加个记忆化。冲冲冲~

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;

inline int sqrtd(int x)
{
    int tot = sqrtl(x);
    while (tot * tot < x) tot++;
    while (tot * tot > x) tot--;
    return tot;
}

inline pair<lld,lld> c(int n)
{
    static unordered_map<int,pair<lld,lld>> _C;
    if (_C.count(n)) return _C[n];
    lld p1 = 0, p2 = 0; int m = sqrtd(n);

    for (int i = 1, j = m; i <= m; i++)
    {
        while (i * i + j * j > n) j--;
        p1 += j;
        if (i & 1) p2 += (j+1)/2;
    }

    return _C[n] = make_pair(p1, p2);
}

inline lld h(int n, bool odd)
{
    lld ans = 0;

    for (int L = 1, R; L <= n; L = R+1)
    {
        R = n / ( n / L );
        auto cur = c(n/L);
        ans += (R-L+1) * cur.first;
        if (odd) ans -= (R-L+1) * cur.second;
    }

    return ans;
}

int main()
{
    const int MAXN = 1e6+5;
    static int pri[MAXN], pcn;
    static char isnp[MAXN], miu[MAXN];
    miu[1] = 1;

    for (int i = 2; i < MAXN; i++)
    {
        if (!isnp[i]) pri[pcn++] = i, miu[i] = -1;
        for (int j = 0; j < pcn && i * pri[j] < MAXN; j++)
        {
            isnp[i * pri[j]] = 1;
            if (i % pri[j] == 0) break;
            miu[i * pri[j]] = -miu[i];
        }
    }

    int T, n;
    scanf("%d", &T);

    while (T--)
    {
        scanf("%d", &n);
        lld ans = 0;
        for (int d = 1; d * d <= n; d++)
            if (miu[d]) ans += miu[d] * h(n/d/d, d&1);
        printf("%lld\n", ans/2);
    }

    return 0;
}

2019CCPC秦皇岛 H. Houraisan Kaguya

题目链接

首先,如果你的群论足够好,你应该能理解到,

f(a,b) = \frac{\operatorname{ord}a}{\gcd(\operatorname{ord}a,\operatorname{ord}b)}

所以题目实际上是求

\sum_{i=1}^n \sum_{j=1}^n \frac{\operatorname{ord}a_i \times \operatorname{ord}a_j}{\gcd(\operatorname{ord}a_i,\operatorname{ord}a_j)^2}

然后现场赛到这里我就不会了(雾)

实际上,我们可以按莫比乌斯反演的套路想到,枚举 gcd。我们考虑到 \operatorname{ord}x 也一定是 p-1 的倍数,所以考虑这样的一个卷积

\begin{aligned} c_k &= \sum_{\gcd(i, j) = k} ijf_if_j \\ &=\sum_{i,j} if_i ~ jf_j [\gcd(i,j)=k] \end{aligned}

其中 f_i 表示 \operatorname{ord} a_x = i 的个数,那么考虑枚举 \gcd 的倍数,则有

\begin{aligned} {c’}_k &= \sum_{x}c_x[k|x] \\ &= \sum_{i,j} if_i ~ jf_j [k|\gcd(i,j)] \\ &= \sum_{i,j} if_i ~ jf_j [k|i][k|j] \\ &= \sum_{i} if_i [k|i] \sum_{j} jf_j [k|j] \end{aligned}

也就是需要作一个类似于

{c’}_k = \sum_{x} c_x[k|x]

的正变换和逆变换。考虑到 FWT 中那个类似于按位与卷积的东西,设计一个类似于高位前缀和的东西,就可以做了。

另外求 \operatorname{ord}x 也有一点注意。朴素的想法是先将 x 拉入模 p 非零元素乘法群的一个 {p_i}^{e_i} 阶子群中,也就是先设计 x_i = x^{\frac{p-1}{{p_i}^{e_i}}},然后求 x_i 在该子群中的周期。所以可以设计一个时间复杂度为 O(c(p-1) \log(p-1)) 的算法。在此过程中,计算的瓶颈在于 x_i 的计算,耗费了绝大部分的时间复杂度,那么我们可以设计一个分治的算法,将因子集合划分为两部分,前半部分将后半部分素因子子群的周期升阶到 1,后半部分同样操作前半部分,然后继续分治。类似于多项式多点求值的一个想法,时间复杂度降为 O(\log c(p-1) \log(p-1))

目前这份代码是CF上跑的最快的,756ms,用普通求阶时间为 1575ms。

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
typedef long double flt;
const int MAXN = 2e5+5, S = 10;
int n, w, fc[20], dcn;
lld mod, fp[20], zs[20];

inline lld mul(lld x, lld y, lld z) { lld t = flt(x) * y / z; lld ans = (x * y - t * z) % z; if (ans < 0) ans += z; return ans; }
inline lld mul(lld x, lld y) { lld t = flt(x) * y / mod; lld ans = (x * y - t * mod) % mod; if (ans < 0) ans += mod; return ans; }
inline void addeq(lld &x, lld y) { x = x+y - (x+y>=mod?mod:0); }
inline void muleq(lld &x, lld y) { lld t = flt(x) * y / mod; lld ans = (x * y - t * mod) % mod; if (ans < 0) ans += mod; x = ans; }
inline lld fpow(lld x, lld n, lld mod) { x %= mod; if (n == 1) return x; lld ret = 1; for (; n; n>>=1, x=mul(x,x,mod)) if (n&1) ret=mul(ret,x,mod); return ret; }
inline lld fpow(lld x, lld n) { x %= mod; if (n == 1) return x; lld ret = 1; for (; n; n>>=1, x=mul(x,x,mod)) if (n&1) ret=mul(ret,x,mod); return ret; }

namespace Decompose {
    int tol; lld factor[1000];

    bool millerRabin(lld n, lld base) {
        lld n2 = n-1, s = __builtin_ctzll(n2); n2 >>= s;
        lld t = fpow(base, n2, n);
        if (t == 1 || t == n-1) return true;
        for (s--; s >= 0; s--)
            if ((t=mul(t,t,n))==n-1)
                return true;
        return false;
    }

    bool isPrime(lld n) {
        static lld bases[12] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37 };
        if (n <= 2) return false;
        for (int i = 0; i < 12 && bases[i] < n; i++)
            if (!millerRabin(n, bases[i])) return false;
        return true;
    }

    inline lld f(lld x, lld mod) {
        lld ans = 1; addeq(ans, mul(x, x, mod)); return ans;
    }

    void PollardRho(lld n) {
        if (n == 1) {
            return;
        } else if (isPrime(n)) {
            factor[tol++] = n;
        } else if (n & 1) {
            for (int i = 1; ; i++) {
                lld x = i, y = f(x, n), q = __gcd(y-x, n);
                while (q == 1) {
                    x = f(x, n), y = f(f(y, n), n);
                    q = __gcd((y-x+n)%n, n) % n;
                }
                if (q != 0 && q != n) {
                    PollardRho(q), PollardRho(n/q);
                    return;
                }
            }
        } else {
            while (!(n & 1)) factor[tol++] = 2, n >>= 1;
            PollardRho(n);
        }
    }

    inline void findfac(lld n) {
        tol = w = 0; PollardRho(n); map<lld,int> tj;
        for (int i = 0; i < tol; i++) tj[factor[i]]++;
        for (auto ss : tj) {
            fc[w] = ss.second, fp[w] = ss.first;
            zs[w] = 1; for (int i = 0; i < fc[w]; i++) zs[w] *= fp[w];
            w++;
        }
    }
}

using Decompose::findfac;

lld getOrd(lld a) {
    lld res = mod-1;
    for (int i = 0; i < w; i++) {
        lld qwq = fpow(a, res/=zs[i]);
        while (qwq != 1) qwq = fpow(qwq, fp[i]), res *= fp[i];
    }
    return res;
} // Time complexity: O(c(n)log(n))

struct Node { int l, r; lld val; };

lld getOrdFast(lld a) {
    static Node Q[100000]; int front = 0, rear = 0;
    lld res = 1; Q[rear++] = Node { 0, w-1, a };
    while (front < rear) {
        Node f = Q[front++];
        if (f.l == f.r) {
            for (lld bs = f.val; bs != 1; bs = fpow(bs, fp[f.l]))
                res *= fp[f.l];
        } else if (f.l < f.r) {
            int mid = (f.l+f.r)>>1;
            lld prod1 = 1, prod2 = 1;
            for (int i = f.l; i <= mid; i++) prod2 *= zs[i];
            for (int i = mid+1; i <= f.r; i++) prod1 *= zs[i];
            Q[rear++] = Node { f.l, mid, fpow(f.val, prod1) };
            Q[rear++] = Node { mid+1, f.r, fpow(f.val, prod2) };
        }
    }
    return res;
} // Time complexity: O(log(c(n))log(n))

lld di[MAXN], invdi2[MAXN]; int qp[20];
lld a[MAXN], g[MAXN];

void genFac(int wi, lld fac, int ids) {
    if (wi == w) {
        di[ids] = fac;
        invdi2[ids] = fac==1?1:mod-mod/fac;
        muleq(invdi2[ids], invdi2[ids]);
    } else {
        lld pr = 1;
        for (int i = 0; i <= fc[wi]; i++, pr *= fp[wi])
            genFac(wi+1, fac*pr, ids+i*qp[wi]);
    }
}

void addResult(lld ord)
{
    int id = 0; lld n = ord;

    for (int i = 0; i < w; i++)
    {
        int r = 0;
        while (n % fp[i] == 0) n /= fp[i], r++;
        id += r * qp[i];
    }

    addeq(g[id], ord);
}

void solve()
{
    findfac(mod - 1);

    qp[0] = 1;
    for (int i = 1; i < w; i++)
        qp[i] = qp[i-1] * (fc[i-1] + 1);
    dcn = qp[w-1] * (fc[w-1] + 1);
    genFac(0, 1, 0);

    for (int i = 1; i <= n; i++)
        addResult(getOrdFast(a[i]));
    for (int i = 0; i < w; i++)
        for (int j = dcn-1; j >= 0; j--)
            if (j / qp[i] % (fc[i]+1) != fc[i])
                addeq(g[j], g[j+qp[i]]);
    for (int i = 0; i < dcn; i++)
        muleq(g[i], g[i]);
    for (int i = 0; i < w; i++)
        for (int j = 0; j < dcn; j++)
            if (j / qp[i] % (fc[i]+1) != fc[i])
                addeq(g[j], mod-g[j+qp[i]]);
    lld ans = 0;
    for (int i = 0; i < dcn; i++)
        addeq(ans, mul(g[i], invdi2[i]));
    printf("%lld\n", ans);
}

int main()
{
    srand(time(NULL));
    scanf("%d %lld", &n, &mod);
    for (int i = 1; i <= n; i++)
        scanf("%lld", &a[i]);
    solve();
    return 0;
}

2019 ICPC上海网络赛部分题解

C. Triple

考虑分情况不同做法。当 n \le 1000 时,枚举 A_i, B_j,则根据不等式解出满足条件的 C_k \in [|A_i-B_j|,A_i+B_j]。当 n > 1000,考虑枚举不合法的方案数,即形如 A_i + B_j < C_k 的方案数,不等式左边可以通过FFT取得,然后前缀和再枚举 C_k 容斥掉即可。

大力施加常数优化,AC仅需1256ms啦啦啦~

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
typedef complex<double> cplx;
int revs[1<<19], A[300010]; cplx wmks[1<<19];
const double PI = acos(-1.0);

void dft(cplx a[], int DFT, int N) {
    int *rev = revs + N;
    for (int i = 0; i < N; i++)
        if (i < revs[N+i]) swap(a[i], a[rev[i]]);
    for (int m = 2, m2 = 1; m <= N; m <<= 1, m2 <<= 1) {
        cplx *wmk = wmks + m + (DFT==-1 ? m2 : 0), u, t;
        for (int k = 0; k < N; k += m)
            for (int j = 0; j < m2; j++)
                t = wmk[j] * a[k+j+m2], u = a[k+j],
                a[k+j] = u + t, a[k+j+m2] = u - t;
    }
    if (DFT == -1) for (int i = 0; i < N; i++) a[i] /= N;
}

int a[100010], b[100010], c[100010];

lld solveWithFFT(int n)
{
    static cplx f[262144], g[262144], h[262144];
    static lld bc[262144], ac[262144], ab[262144];
    int ma = 0, mb = 0, mc = 0, md, len = 1; cplx tmp, tmp2;
    for (int i = 0; i < n; i++) ma = max(ma, a[i]), mb = max(mb, b[i]), mc = max(mc, c[i]);
    md = max(ma+mb, max(mb+mc, mc+ma)) + 2; while (len < md) len <<= 1;
    for (int i = 0; i < len; i++) f[i] = g[i] = h[i] = 0;
    for (int i = 0; i < n; i++) f[a[i]] += 1, g[b[i]] += 1, h[c[i]] += 1;
    dft(f, 1, len), dft(g, 1, len), dft(h, 1, len);
    for (int i = 0; i < len; i++) tmp = f[i] * g[i], tmp2 = g[i] * h[i], g[i] = f[i] * h[i], f[i] = tmp2, h[i] = tmp;
    dft(f, -1, len), dft(g, -1, len), dft(h, -1, len);
    for (int i = 0; i < len; i++) bc[i] = f[i].real()+0.2, ac[i] = g[i].real()+0.2, ab[i] = h[i].real()+0.2;
    for (int i = 1; i < len; i++) bc[i] += bc[i-1], ac[i] += ac[i-1], ab[i] += ab[i-1];
    lld ans = 1LL * n * n * n;

    for (int i = 0; i < n; i++)
    {
        if (c[i] <= len) ans -= ab[c[i]-1];
        if (b[i] <= len) ans -= ac[b[i]-1];
        if (a[i] <= len) ans -= bc[a[i]-1];
    }

    return ans;
}

inline int abs(int x) { return x < 0 ? -x : x; }

lld solveWithBruteForce(int n)
{
    for (int i = 0; i <= 200001; i++) A[i] = 0;
    for (int i = 0; i < n; i++) A[c[i]]++;
    for (int i = 1; i <= 200001; i++) A[i] += A[i-1];
    lld ans = 0;
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            ans += A[a[i]+b[j]] - A[max(abs(a[i]-b[j])-1,0)];
    return ans;
}

void solve(int cas)
{
    int n; scanf("%d", &n);
    for (int i = 0; i < n; i++) scanf("%d", &a[i]);
    for (int i = 0; i < n; i++) scanf("%d", &b[i]);
    for (int i = 0; i < n; i++) scanf("%d", &c[i]);
    printf("Case #%d: %lld\n", cas, n <= 1000 ? solveWithBruteForce(n) : solveWithFFT(n));
}

int main()
{
    for (int N = 2, N2 = 1; N <= (1<<18); N <<= 1, N2 <<= 1)
    {
        int *rev = revs + N; cplx *wmk = wmks + N;
        for (int i = 0; i < N; i++)
            rev[i] = (rev[i>>1]>>1)|((i&1)?N2:0);
        for (int i = 0; i < N2; i++)
            wmk[i] = cplx(cos(PI*i/N2), sin(PI*i/N2));
        for (int i = 0; i < N2; i++)
            wmk[i+N2] = cplx(wmk[i].real(), -wmk[i].imag());
    }

    int T; scanf("%d", &T);
    for (int i = 1; i <= T; i++) solve(i);
    return 0;
}

D. Counting Sequences I

首先发现是几个本质不同的方案的排列,于是想到打表。需要一个快速出结果的剪枝。

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
const int MOD = 1e9+7;
int n, jspx[3010], val[3010];
lld fac[3010], inv[3010], invs[3010], ans;

void dfs(int cur, int last, lld mul, int sum)
{
    if (cur == n)
    {
        if (mul == sum)
        {
            for (int i = 1; i <= n; i++)
                if (jspx[i])
                    printf("%dx%d ", jspx[i], i);
            printf("\n");
            lld tot = fac[n];
            for (int i = 1; i <= n; i++)
                tot = tot * invs[jspx[i]] % MOD;
            ans = (ans + tot) % MOD;
        }
    }
    else if (mul * pow(last, n-cur) <= sum + last * (n - cur))
    {
        for (int i = last; i <= n; i++)
        {
            val[cur] = i;
            jspx[i]++;
            dfs(cur + 1, i, mul * i, sum + i);
            jspx[i]--;
        }
    }
}

int main()
{
    fac[0] = fac[1] = invs[0] = invs[1] = inv[1] = 1;
    for (int i = 2; i < 3010; i++)
        fac[i] = i * fac[i-1] % MOD,
        inv[i] = (MOD-MOD/i) * inv[MOD%i] % MOD,
        invs[i] = invs[i-1] * inv[i] % MOD;
    scanf("%d", &n);
    dfs(0, 1, 1, 0);
    printf("ANS = %lld\n", ans);
    return 0;
}

E. Counting Sequences II

不知道题解在写什么鬼东西……

考虑 [1,m] 中的每个数字 t 的排列型生成函数。如果 t 为奇数,则其生成函数为 f(x) = 1 + \frac{x}{1!} + \frac{x^2}{2!} + \cdots = e^x;否则只能出现偶数次,则为 g(x) = 1 + \frac{x^2}{2!} + \frac{x^4}{4!} + \cdots = \frac{e^x + e^{-x}}{2} = \ch(x)

那么答案相当于求

\left(\frac{e^x+e^{-x}}{2}\right)^{\left\lfloor \frac{m}{2} \right\rfloor} \left(e^x\right)^{\left\lceil \frac{m}{2} \right\rceil}

\frac{x^n}{n!} 项系数。注意到 \left\lceil \frac{m}{2} \right\rceil = \left\lfloor \frac{m}{2} \right\rfloor + (m \mod 2),令 k = \left\lfloor \frac{m}{2} \right\rfloor,则可以化简为

\frac{1}{2^k} \sum_{s=0}^k C_{k}^{s} e^{(2s+(m\mod 2))x}

即最后答案为

\frac{1}{2^k} \sum_{s=0}^k \frac{k! (2s+(m\mod 2))^n}{s! (k-s)!}

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
const int MOD = 1e9+7, MAXN = 2e5+5;
int fac[MAXN], inv[MAXN], invs[MAXN];

int fpow(lld a, lld k)
{
    lld b = 1; k %= MOD-1;
    for (; k; k >>= 1, a = a * a % MOD)
        if (k & 1) b = b * a % MOD;
    return b;
}

void solve()
{
    lld ans = 0, n; int m, k;
    scanf("%lld %d", &n, &m); k = m / 2;
    for (int s = 0; s <= k; s++)
        ans += 1ll * fac[k] * invs[s] % MOD * invs[k-s] % MOD * fpow(2*s+(m&1), n) % MOD;
    ans = ans % MOD * fpow((MOD+1)/2, k) % MOD;
    printf("%lld\n", ans);
}

int main()
{
    fac[0] = fac[1] = inv[1] = 1;
    invs[0] = invs[1] = 1;
    for (int i = 2; i < MAXN; i++)
        fac[i] = 1ll * i * fac[i-1] % MOD,
        inv[i] = 1ll * (MOD-MOD/i) * inv[MOD%i] % MOD,
        invs[i] = 1ll * invs[i-1] * inv[i] % MOD;
    int T; scanf("%d", &T);
    while (T--) solve();
    return 0;
}

H. Luhhy’s Matrix

考虑让每个矩阵只被运算一次,那么可以使用类似于分块的想法设立分界线。

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
const int MAXN = 2e5+5;
unsigned lastans, seed, pow17[16], pow19[16];
char cost[MAXN];

struct matrix
{
    char item[16][16];
    int col[16], row[16];
    // let row[k] = item[i][k], col[k] = item[k][i]

    void norm()
    {
        for (int k = 0; k < 16; k++)
        {
            row[k] = col[k] = 0;
            for (int i = 0; i < 16; i++)
                row[k] = (row[k] << 1) | (item[i][k] & 1),
                col[k] = (col[k] << 1) | (item[k][i] & 1);
        }
    }

    matrix operator*(const matrix &b) const
    {
        matrix ans;
        for (int i = 0; i < 16; i++)
            for (int j = 0; j < 16; j++)
                ans.item[i][j] = cost[col[i]&b.row[j]];
        ans.norm();
        return ans;
    }

    static matrix getOne()
    {
        matrix ans; memset(ans.item, 0, sizeof(ans.item));
        for (int i = 0; i < 16; i++) ans.item[i][i] = 1;
        ans.norm(); return ans;
    }

    unsigned getAns() const
    {
        unsigned ans = 0;
        for (int i = 0; i < 16; i++)
            for (int j = 0; j < 16; j++)
                ans += item[i][j] * pow17[i] * pow19[j];
        return ans;
    }

    void output() const
    {
        for (int i = 0; i < 16; i++)
            for (int j = 0; j < 16; j++)
                printf("%d%c", int(item[i][j]), " \n"[j==15]);
    }

    static matrix generate()
    {
        matrix ans;
        seed ^= lastans;
        for (int i = 0; i < 16; i++)
        {
            seed ^= seed * seed + 15;
            for (int j = 0; j < 16; j++)
                ans.item[i][j] = (seed >> j) & 1;
        }
        ans.norm();
        return ans;
    }
} Q[MAXN], tail;

void solve()
{
    int front = 0, rear = 0, cut = 0, n; lastans = 0;
    scanf("%d", &n); tail = matrix::getOne();

    while (n--)
    {
        int t; scanf("%d %u", &t, &seed);

        if (t == 1)
        {
            Q[rear++] = matrix::generate();
            if (rear == front+1) tail = matrix::getOne(), cut = rear;
            else tail = Q[rear-1] * tail;
        }
        else
        {
            front++;

            if (front == cut)
            {
                cut = rear;
                for (int i = rear-2; i >= front; i--) Q[i] = Q[i+1] * Q[i];
                tail = matrix::getOne();
            }
        }

        lastans = (front == rear) ? 0 : (tail * Q[front]).getAns();
        printf("%u\n", lastans);
    }
}

int main()
{
    for (int sz2 = 1; sz2 < 65536; sz2 <<= 1)
        for (int i = 0; i < sz2; i++)
            cost[sz2+i] = cost[i] ^ 1;
    pow17[0] = pow19[0] = 1;
    for (int i = 1; i < 16; i++)
        pow17[i] = pow17[i-1] * 17,
        pow19[i] = pow19[i-1] * 19;
    int T; scanf("%d", &T);
    while (T--) solve();
    return 0;
}

K. Peekaboo

圆上整点。直接分解 a,b 然后暴力枚举点对。看题解以为卡了64倍常数,结果交了就过了(大雾)

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
typedef pair<int,int> pear;
#define F first
#define S second
inline lld sqr(int a) { return 1ll * a * a; }
lld gcd(lld a, lld b) { return b ? gcd(b, a%b) : a; }

void gaussInteger(lld R, vector<pear> &tot)
{
    auto check = [&tot] (lld r, lld rr) -> void
    {
        if (r == 1 || r % 4 != 1) return;
        for (int n = 1; n*n*2 < r; n++)
        {
            int m = int(sqrt(r-n*n)+1e-5);
            if (m * m + n * n != r) continue;
            if (gcd(m,n) != 1 || m <= n) continue;
            tot.emplace_back(rr*(m*m-n*n), 2*n*m*rr);
        }
    };

    for (int i = 1; i * i <= R; i++)
    {
        if (R % i) continue;
        check(R/i, i);
        if (i * i != R) check(i, R/i);
    }
}

void fillOver(vector<pear> &proc, int r)
{
    int cnt = proc.size();
    proc.resize(cnt*8+4);
    for (int i = 0; i < cnt; i++)
        proc[i+cnt] = pear(proc[i].S, proc[i].F);
    for (int i = 0; i < 2*cnt; i++)
        proc[i+cnt*2] = pear(proc[i].F, -proc[i].S),
        proc[i+cnt*4] = pear(-proc[i].F, proc[i].S),
        proc[i+cnt*6] = pear(-proc[i].F, -proc[i].S);
    proc[cnt*8] = pear(0, r);
    proc[cnt*8+1] = pear(r, 0);
    proc[cnt*8+2] = pear(0, -r);
    proc[cnt*8+3] = pear(-r, 0);
}

pair<pear,pear> ans[100050];

void solve()
{
    vector<pear> possA, possB;
    int a, b, c, tot = 0;
    scanf("%d %d %d", &a, &b, &c);
    gaussInteger(a, possA), gaussInteger(b, possB);
    fillOver(possA, a), fillOver(possB, b);

    for (auto A : possA) for (auto B : possB)
    {
        lld r2 = sqr(c), r3 = sqr(A.F-B.F)+sqr(A.S-B.S);
        if (r2 == r3) ans[tot++] = make_pair(A, B);
    }

    sort(ans, ans+tot);
    printf("%d\n", tot);
    for (int i = 0; i < tot; i++) printf("%d %d %d %d\n", ans[i].F.F, ans[i].F.S, ans[i].S.F, ans[i].S.S);
}

int main()
{
    int T; scanf("%d", &T);
    while (T--) solve();
    return 0;
}

HDU6414 带劲的多项式

考虑形如

f(x) = \prod_{i=1}^r (x-\lambda_i)^{l_i}

形式的多项式,它的一阶导函数为

f'(x) = \left( \sum_{i=1}^{r}l_i(x-\lambda_i) \right) \left(\prod_{i=1}^r (x-\lambda_i)^{l_i -1}\right)

则可以发现它们之间有

\gcd(f(x), f'(x)) = \prod_{i=1}^r (x-\lambda_i)^{l_i-1}

则有

\frac{f(x)}{\gcd(f(x), f'(x))} = \prod_{i=1}^r (x-\lambda_i)

由于根重数不同,每个根的因子会在不同时刻从上式消失。每次将 \gcd(f, f’) 当作新的 f 进行迭代即可。

另外,STL好慢啊233。

#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
const int MOD = 998244353;
typedef vector<int> Poly;

inline lld fpow(lld a, int k)
{
    lld b = 1;
    for (; k; k>>=1, a=a*a%MOD)
        if (k&1) b=b*a%MOD;
    return b;
}

void norm(Poly &a)
{
    lld qwq = fpow(a.back(), MOD-2);
    for (auto &x : a) x = x * qwq % MOD;
}

Poly diff(const Poly &a)
{
    Poly _a = Poly(a.size() - 1);
    for (int i = 0; i < _a.size(); i++)
        _a[i] = a[i+1] * (i+1LL) % MOD;
    return _a;
}

void divide(const Poly &a, const Poly &p, Poly &q, Poly &r)
{
    assert(p.size() >= 1 && p.back() != 0);
    q.clear(), r = a; int P = p.size();
    lld pq = fpow(p.back(), MOD-2);

    for (int i = a.size()-1; i >= P-1; i--)
    {
        q.push_back(r[i] * pq % MOD);
        for (int j = 0; j < P; j++)
            r[i-P+1+j] = (r[i-P+1+j] + MOD - 1LL * r[i] * pq % MOD * p[j] % MOD) % MOD;
    }

    reverse(q.begin(), q.end());
    while (q.back() == 0) q.pop_back();
    while (r.back() == 0) r.pop_back();
}

Poly gcd(Poly a, Poly b)
{
    Poly c, d;
    while (b.size()) divide(a, b, c, d), a = b, b = d;
    norm(a);
    return a;
}

void solve()
{
    int n; scanf("%d", &n); vector<pair<int,int>> ans;
    Poly orig(n+1, 0), dif, cbs, tmp, nil, nul;
    for (int i = 0; i <= n; i++) scanf("%d", &orig[i]);
    norm(orig); dif = diff(orig), tmp = gcd(orig, dif);
    divide(orig, tmp, cbs, nil); orig = tmp, norm(orig);

    for (int l = 1; dif.size(); l++)
    {
        dif = diff(orig), tmp = gcd(orig, dif);
        divide(orig, tmp, nul, nil);

        if (nul != cbs)
        {
            // an root appears!
            Poly a, b;
            divide(cbs, nul, b, a);
            assert(b.size() == 2);
            ans.emplace_back((MOD-b[0])%MOD, l);
            cbs = nul;
        }

        orig = tmp, norm(orig);
    }

    printf("%d\n", int(ans.size()));
    for (auto p : ans) printf("%d %d\n", p.first, p.second);
}

int main()
{
    int T; scanf("%d", &T);
    while (T--) solve();
    return 0;
}