本文对 CTGAN 的原理和实现进行简单复述。论文原文在 arXiv 上,参考的代码版本是 v0.2.2。
数据表示
对连续列,采取模式特定的正则化(mode-specific normalization)。这种方法能够
... convert continuous values of arbitrary range and distribution into a bounded vector representation suitable for neural networks. 1
这是因为
Neural networks can effectively generate values with a distribution centered over (−1, 1) using tanh, as well as a low-cardinality multinomial distribution using softmax. 2
提出这个正则化的目标是把输入映射到 (-1, 1) 的区间上。
使用变分高斯混合模型(variational Gaussian mixture model),原始的分布可通过下述公式拟合
$\displaystyle p(c)=\sum_{k=1}^{K}\pi_kN(c~\vert~\mu_k,\Sigma_k)$
3
其中 $c$ 为某一连续列的某一值。注意到实际一列视为一个变量,协方差矩阵 $\Sigma$ 退化成方差 $\sigma^2$。
对拟合好的 VGM,可以计算任一值 $c$ 来自各个模式(mode,或作 state,component)的概率 $\rho_k$。对各个模式进行一次随机采样,各模式被采样的概率即是各自对应的 $\rho_k$。设模式 $k^*$ 被采样,则用该模式进行正则化。其中
$\alpha=\cfrac{c-\mu_{k^*}}{4~\sigma_{k^*}}$
$\beta$ 为表示 $k^*$ 被采样的 one-hot 向量
注意到极大多数情况下,$c$ 服从被选上的高斯分布,因而这个正则化是合理的。另外,在 TGAN 中,分母中的常数用的是 2。实际上,最后正则化的值会另被钳在 (-0.99, 0.99) 的区间中。
对离散列,则使用 one-hot 向量表示。
最终一行数据可表为
$r=\alpha_1\oplus\beta_1\oplus\cdots\oplus\alpha_{N_c}\oplus\beta_{N_c}\oplus d_1\oplus\cdots\oplus d_{N_d}$
实现细节
先把 train_data
过一遍 DataTransformer
,分 fit
和 transform
两步。
DataTransformer.fit()
fit()
中对数据产生了如下 meta 信息。
- 连续列
|
|
4 其中 components
为数组表示 VGM 中假设的各个分量是否激活,num_components
表示激活的分量总数。
|
|
4 重要的是 weight_concentration_prior
参数。通过指定一个低浓度先验,将会使模型将大部分的权重分配到少数分量上,而其余分量的权重则趋近 0 5。之后调用 BayesianGaussianMixture.fit()
使用 EM 算法估计模型的参数,上述 weight_concentration_prior
即对应 Bishop 书中的 $\gamma$。
- 离散列
|
|
4 即,使用 one-hot 向量表示,categories
为该列的类别(category)数。
上述两种结构都推入 meta
。
DataTransformer.transform()
包含 _transform_continuous()
和 _transform_discrete()
。后者是简单的变为 one-hot 向量,前者实现上有些操作。
data
的形状为 (len, 1),其中 len 为数据集的行数。features
的形状为 (len, n_clusters)。随机选取的时候,随机选取的概率为 Bayesian Gaussian Mixture 预测的概率加上 1e-6 再取概率。最后输出 features
(即 $\alpha$)时,会 np.clip()
到 (-.99, .99) 的区间中。
用正态分布构造示例数据,经 transfrom 得样例输出
|
|
依次为连续列 $C_1$、连续列 $C_2$、离散列 $D_1$、离散列 $D_2$。$C_1$ 中采样到 3 个模式;$C_2$ 中 1 个;$D_1$ 有 3 个类别;$D_2$ 有 2 个。
条件生成器
引入条件生成器(conditional generator)以期望解决离散列的类别不平衡(class imbalance)问题。该生成器的输出是一个 one-hot 形式的条件向量(conditional vector),其长度为各离散列类别数之和,即 $\sum_{i=1}^{N_d}|D_i|$。这个条件向量是随机的,其各个离散列的可能性均等,但列内各类别的可能性与数据集的相同。
条件生成器使得生成器原本要拟合的左端项
$\displaystyle P(\text{row})=\sum_{k\in D_i} P(\text{row}|D_i=k)P(D_i=k)$ 1
变为在给定 $D_i=k$ 的条件下拟合 $P(\text{row}|D_i=k)$,以期小模式也被充分采样。
注意这并非 GAN 意义下的生成器,而更类似于采样器。v0.3.0 版本中与采样器合并。
实现细节
本小节中设数组下标从 1 开始。
记离散列数为 $N_d$,第 $i$ 个离散列 $D_i$ 的类别(category)数为 $|D_i|$。记数据集行数为 len
,批大小为 batch
。
ConditionalGenerator.__init__()
构造函数中,对数据集进行处理。n_col
记录离散列数 $N_d$。n_opt
记录各离散列类别数之和,即 $\sum_{i=1}^{N_d}|D_i|$。model
形状为 (n_col, len)
,其每一列表示数据集对应行对应的离散类别序号;这个变量只在整个模型训练完毕后,生成数据时在不传入条件向量的情况下使用。interval
形状为 (n_col, 2)
,意义为
$\displaystyle \text{interval}[j] = \left( \sum_{k=0}^{j-1} |D_k| , \; |D_{j}| \right) \quad (j=1,2,\cdots,N_d, \; |D_0|=0)$
重要的是计算出 p
,它表征了离散列在数据集中的先验。其形状为 (n_col, max_interval)
,其中 max_interval
表示 $\max|D_i|$。对第 $i$ 行,其第 $1\le j\le|D_i|$ 列表示离散列 $D_i$ 取其第 $j$ 个类别的对数频率(log frequency),余下的列均为 0。
ConditionalGenerator.sample()
通过 sample()
采样。先在 $[1,N_d]$ 的整数中等可能地随机抽取,生成长度为 batch
的序列 idx
,表示 batch
次随机到的离散列的序号。mask1
的形状为 (batch, n_col)
,每行为 one-hot 向量,表示该次随机到的离散列。
另有序列 opt1prime
长度为 batch
,表示从 idx
对应离散列中随机到的类别的(列内的)序号。具体方法是从 $U(0,1)$ 中采样 r
,根据 p
的累积分布函数选择 $\underset{j}{\argmin}\operatorname{CDF}_p(j)>r$ 的类别。
最后 vec1
形状为 (batch, n_opt)
,各行是对应的条件向量 1。
与训练中变量的对应关系:
c1, m1, col, opt = vec1, mask1, idx, opt1prime
4
采样器
采样器用于从数据集中采样。v0.3.0 版本中与条件生成器合并。
实现细节
Sampler.model
为三维列表,第一维长度是离散列数 $N_d$,第二维长度是对应列的类别数 $|D_i|$,第三维存储离散列 $D_i$ 取类别 $j$ 的行的序号。
采样时,Sampler.sample()
根据给定的 col
和 opt
,从满足该条件的行中随机等可能选取一行。
训练过程
生成器
随机产生某定长向量 $z$,使各分量独立地采样自 $N(0,1)$。通过条件生成器采样向量 $cond$,其长度为各离散列的类别数之和。输入网络
$h_0=z\oplus cond$
1
定义
$\operatorname{Residual}_{i\rarr o}(h)=[\operatorname{ReLU}\circ\operatorname{BatchNorm1d}\circ\operatorname{Linear}_{i\rarr o}(h)]\oplus h$
则
$h_1=\operatorname{Residual}_{\lVert h_0\rVert\rarr 256}(h_0)$
$h_2=\operatorname{Residual}_{\lVert h_1\rVert\rarr 256}(h_1)$
$\operatorname{G}(\cdot)=\operatorname{Linear}_{\lVert h_2\rVert\rarr\lVert r\rVert}(h_2)$
生成器输出
$r=\alpha_1\oplus\beta_1\oplus\cdots\oplus\alpha_{N_c}\oplus\beta_{N_c}\oplus d_1\oplus\cdots\oplus d_{N_d}$
再做
$\hat \alpha_i=\tanh(\alpha_i)$
$\hat \beta_i=\operatorname{gumbel}_{0.2}(\beta_i)$
$\hat d_i=\operatorname{gumbel}_{0.2}(d_i)$
得到最终 $\hat r$。
损失函数的定义基于 WGAN
$-E[D(\hat x)]$
6
再加上 $\hat r$ 中的 $\hat d$ 中被 $mask$ 部分的交叉熵
$loss=-E[D(\hat x)]+E[\operatorname{CrossEntropy}_{mask}(\hat d,cond)]$
1
注意实际输入的 $x$ 包括行向量 $r$ 和条件向量 $cond$。
判别器
参考 PacGAN,将 pac 个样本作为一个包(packet),以期防止模式坍缩(mode collapse)。7
$h_0=r_1\oplus cond_1\oplus\cdots\oplus r_{pac}\oplus cond_{pac}$
$h_1=\operatorname{Dropout}_{0.5}\circ\operatorname{LeakyReLU}_{0.2}\circ\operatorname{Linear}_{\lVert h_0\rVert\rarr 256}(h_0)$
$h_2=\operatorname{Dropout}_{0.5}\circ\operatorname{LeakyReLU}_{0.2}\circ\operatorname{Linear}_{\lVert h_1\rVert\rarr 256}(h_1)$
$\operatorname{D}(\cdot)=\operatorname{Linear}_{\lVert h_2\rVert\rarr1}(h_2)$
损失函数的定义参考 WGAN-GP,在 WGAN 原本的损失函数的基础上加入梯度惩罚(gradient penalty)以使训练收敛
$loss=E_{\hat x\sim P_{gen}}[D(\hat x)]-E_{x\sim P_{real}}[D(x)]+\lambda E_{\tilde x\sim P_{pen}}[\lVert\nabla_{\tilde x}D(\tilde x)\rVert_2-1]^2$
8
其中 $\lambda=10$,$P_{pen}$ 采样自 $P_{gen}$ 到 $P_{real}$ 的随机线性内插值。注意实际输入的 $x$ 包括行向量 $r$ 和条件向量 $cond$。
实现细节
fakez
的形状为 (batch_size, embedding_dim)
即 (500, 128),各元素独立地采样自 $N(0,1)$。condvec
采样自条件生成器,四个项分别作 c1, m1, col, opt
。将 c1
接到 fakez
之后。
生成随机置换 perm
,将其作用于 col, opt, c1
。从数据集中采样 batch_size
个符合置换后的 col, opt
的行 real
;将置换后的 c1
(记为 c2
)接在其后,得 real_cat
。
生成器输入长度为 embedding_dim + n_opt
。fake = generator(fakez)
。对 fake
中的对应分量施加函数 _apply_activate()
,再接上 c1
得 fake_cat
。
计算 y_fake = discriminator(fake_cat)
和 y_real = discriminator(real_cat)
。
梯度惩罚的计算中,注意来自同一个包的样本插值量相同。
优化器为 Adam,学习率为 2e-4。
采样/合成
若给出 cond
则将 cond
作为条件输入生成器;若不给定 cond
,则按「各个离散列的可能性均等,但列内各类别的可能性与数据集的相同」的方式产生 cond
。
Lei Xu et al. Modeling tabular data using conditional gan. In Advances in Neural Information Processing Systems, 2019. ↩︎
Lei Xu and Kalyan Veeramachaneni. Synthesizing tabular data using generative adversarial networks. arXiv preprint arXiv:1811.11264, 2018. ↩︎
Christopher M Bishop. Pattern recognition and machine learning. springer, 2006. ↩︎
https://github.com/sdv-dev/CTGAN/tree/v0.2.2 ↩︎
https://scikit-learn.org/stable/modules/mixture.html#bgmm ↩︎
Martin Arjovsky et al. Wasserstein GAN. In International Conference on Machine Learning, 2017. ↩︎
Zinan Lin et al. PacGAN: The power of two samples in generative adversarial networks. In Advances in Neural Information Processing Systems, 2018. ↩︎
Ishaan Gulrajani et al. Improved Training of Wasserstein GANs. In Advances in Neural Information Processing Systems, 2017. ↩︎