Reference Information Extraction:人脸id的保持还是通过使用arcface人脸识别模型来获取embedding,该阶段首先获取feature
F
f
a
c
e
F_{face}
Fface,对于服装发型等,先将人物分割出来,之后使用clip vision encoder提取对应的特征
F
c
h
a
r
a
c
t
e
r
F_{character}
Fcharacter,对应的是图中的facial encoder和portrait encoder;
Reference Information Refinement by Positional-aware Perceiver Resampler:上一部分获取的
F
f
a
c
e
,
F
c
h
a
r
a
c
t
e
r
F_{face},F_{character}
Fface,Fcharacter分别经过独立的Resample模块得到的embedding进行concat,然后和positional embedding进行相加,这里的positional embedding的作用是用来区分不同的人物,此处还加上了背景的embedding
E
b
g
E_{bg}
Ebg,公式化如下:
E
1
=
R
1
(
F
f
a
c
e
)
,
E
2
=
R
2
(
F
c
h
a
r
a
c
t
e
r
)
,
E
i
=
M
L
P
(
C
a
t
(
E
1
,
E
2
)
+
E
p
o
s
)
,
c
i
=
C
a
t
(
E
b
g
,
R
e
s
h
a
p
e
(
E
i
,
(
N
×
L
,
D
)
)
E_1 = R_1(F_{face}),E_2 = R_2(F_{character}),E_i = MLP(Cat(E_1,E_2) + E_{pos}),c_i = Cat(E_{bg},Reshape(E_i,(N \times L,D))
E1=R1(Fface),E2=R2(Fcharacter),Ei=MLP(Cat(E1,E2)+Epos),ci=Cat(Ebg,Reshape(Ei,(N×L,D)),最后的
c
i
c_i
ci就是image cross attention的条件,作为条件的方式和IP adapter一样;
Pose Decoupling from Character Images:对于姿态这一块,为了增强多样性,所以在模型的基础上加上了pose controlnet模块来控制生成人物的姿态,在推理的时候也可以不使用这一块;
Training with LoRA:在训练的时候使用了之前训练模型ip-adapter的参数,模型的cross attention模块可训练的参数是新添加的lora对应的参数,
Q
=
Z
(
W
q
+
Δ
W
q
)
,
K
t
=
c
t
(
W
k
t
+
Δ
W
k
t
)
,
V
t
=
c
t
(
W
v
t
+
Δ
W
v
t
)
,
K
i
=
c
i
(
W
k
i
+
Δ
W
k
i
)
,
V
i
=
c
i
(
W
v
i
+
Δ
W
v
i
)
Q = Z(W_q + \Delta W_q),K_t = c_t(W^t_k + \Delta W_k^t),V_t = c_t(W_v^t + \Delta W_v^t),K_i = c_i(W_k^i + \Delta W^i_k),V_i = c_i(W_v^i + \Delta W^i_v)
Q=Z(Wq+ΔWq),Kt=ct(Wkt+ΔWkt),Vt=ct(Wvt+ΔWvt),Ki=ci(Wki+ΔWki),Vi=ci(Wvi+ΔWvi),其中可以训练的参数是
Δ
\Delta
Δ的部分;
Loss Constraints on Cross-attention Maps with Masks:这一块为了保证各个部分(背景+多character)之间不相互影响,新加入了一个损失,在每一层的cross attention部分,取出image的attention map,,然后将attention在L维度进行拍平,得到的结果和mask计算损失:
P
=
S
o
f
t
m
a
x
(
Q
K
T
/
d
)
,
A
=
∑
k
=
1
L
P
k
P = Softmax(QK^T/\sqrt{d}),A = \sum_{k = 1}^L P_k
P=Softmax(QKT/d),A=∑k=1LPk,损失为
L
a
t
t
n
=
1
N
+
1
∑
k
=
1
N
+
1
∣
∣
A
k
−
M
k
∣
∣
2
2
L_{attn} = \frac{1}{N + 1}\sum_{k = 1}^{N + 1}||A_k - M_k||^2_2
Lattn=N+11∑k=1N+1∣∣Ak−Mk∣∣22,其中N是character的数目,L表示每个character的token数目;