D^3CTF 2025

Misc

d3image

Challenge

我一定是训练模型训练出了幻觉,怎么从这张图里看出了 “不存在” 的文字?

Solution

为了还原 mysterious_invitation.png 中隐藏的信息,我们需要实现 your_decode_net 函数。根据编码过程,d3net 是一个可逆网络,它将 DWT 变换后的封面图像和秘密信息合并后进行转换。因此,your_decode_net 实际上是 d3net 的逆过程。

具体步骤如下:

  1. 实现 INV_block 的逆操作 INV_block_reverse:INV_block 是 d3net 的基本组成单元。我们需要根据其前向传播的数学关系,推导出对应的逆运算以恢复原始输入。
  2. 实现 D3net 的逆操作 D3net_reverse:D3net 由多个 INV_block 串联组成。其逆操作就是将 INV_block_reverse 按相反的顺序串联起来。
  3. 在 decode 函数中使用 D3net_reverse:
    • 将待解码的图片进行 DWT 变换。
    • 构建 D3net_reverse 的输入。由于 d3net 的前向传播是 (cover_dwt, payload_dwt) -> (stego_dwt, z_channels),那么其逆向传播就是 (stego_dwt, z_prior) -> (recovered_cover_dwt, recovered_payload_dwt)。这里的 z_prior 通常是一个全零张量,表示编码时被压缩或推向零的隐变量。
    • 运行 D3net_reverse 以获得恢复的秘密信息 DWT。
    • 对恢复的秘密信息 DWT 应用 IWT,还原为原始的位图表示。
    • 最后,将位图转换为文本信息。

下面是修改后的文件内容:

block.py:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
import torch.nn as nn
from utils import initialize_weights

# Dense connection
class ResidualDenseBlock_out(nn.Module):
    """Five-layer densely connected 3x3 conv block (12 channels in, 12 out).

    Every convolution receives the block input concatenated with the outputs
    of all earlier layers. The first four layers add 32 feature maps each and
    are followed by an in-place LeakyReLU; the fifth maps back to 12 channels
    and is initialised with scale 0 through ``initialize_weights``.
    """

    def __init__(self, bias=True):
        super().__init__()
        self.channel = 12
        self.hidden_size = 32
        base, growth = self.channel, self.hidden_size
        # Attribute names conv1..conv5 are kept unchanged: nn.Module registers
        # sub-modules under their attribute name.
        self.conv1 = nn.Conv2d(base, growth, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(base + growth, growth, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(base + 2 * growth, growth, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(base + 3 * growth, growth, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(base + 4 * growth, base, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(inplace=True)
        # Zero-scale the last convolution's weights.
        initialize_weights([self.conv5], 0.)

    def forward(self, x):
        """Run the dense block on ``x`` and return the conv5 output."""
        features = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            stacked = torch.cat(features, 1)
            features.append(self.lrelu(conv(stacked)))
        return self.conv5(torch.cat(features, 1))

class INV_block(nn.Module):
    """Affine coupling block acting on two halves of the channel axis.

    The input is split along dim 1 into ``x1`` and ``x2`` (``channels * 4``
    channels each). The block returns ``cat(y1, y2)`` with
    ``y1 = x1 + f(x2)`` and ``y2 = e(r(y1)) * x2 + y(y1)``.
    """

    def __init__(self, clamp=2.0):
        super().__init__()
        self.channels = 3
        self.clamp = clamp
        self.r = ResidualDenseBlock_out()  # ρ: input to the scale term e(.)
        self.y = ResidualDenseBlock_out()  # η: additive term for the second half
        self.f = ResidualDenseBlock_out()  # φ: additive term for the first half

    def e(self, s):
        """Return ``exp(clamp * 2 * (sigmoid(s) - 0.5))``."""
        centred = torch.sigmoid(s) - 0.5
        return torch.exp(self.clamp * 2 * centred)

    def forward(self, x):
        """Apply the coupling transform to ``x`` and return ``cat(y1, y2)``."""
        half = self.channels * 4
        x1 = torch.narrow(x, 1, 0, half)
        x2 = torch.narrow(x, 1, half, half)

        y1 = x1 + self.f(x2)
        log_scale = self.r(y1)
        shift = self.y(y1)
        y2 = self.e(log_scale) * x2 + shift

        return torch.cat((y1, y2), 1)

# Added for inverse operation
class INV_block_reverse(nn.Module):
    """Analytic inverse of ``INV_block`` built on the forward block's sub-networks.

    Given ``cat(y1, y2)`` as produced by ``INV_block.forward`` it computes
    ``x2 = (y2 - y(y1)) / e(r(y1))`` followed by ``x1 = y1 - f(x2)`` and
    returns ``cat(x1, x2)``.
    """

    def __init__(self, inv_block_instance):
        super().__init__()
        source = inv_block_instance
        # Share (do not copy) the sub-networks so the same weights are used.
        self.r, self.y, self.f = source.r, source.y, source.f
        self.channels = source.channels
        self.clamp = source.clamp

    def e(self, s):
        """Return ``exp(clamp * 2 * (sigmoid(s) - 0.5))`` (same as the forward block)."""
        centred = torch.sigmoid(s) - 0.5
        return torch.exp(self.clamp * 2 * centred)

    def forward(self, y_cat):
        """Invert one coupling step; ``y_cat`` is ``cat(y1, y2)`` along dim 1."""
        half = self.channels * 4
        y1 = torch.narrow(y_cat, 1, 0, half)
        y2 = torch.narrow(y_cat, 1, half, half)

        # y1 is available unchanged, so the scale and shift can be recomputed.
        log_scale = self.r(y1)
        shift = self.y(y1)

        # Undo y2 = e(s1) * x2 + t1.
        x2 = (y2 - shift) / self.e(log_scale)

        # Undo y1 = x1 + f(x2) now that x2 is known.
        x1 = y1 - self.f(x2)

        return torch.cat((x1, x2), 1)

utils.py:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch.nn as nn
import torch.nn.init as init
import torch
import numpy as np
import math
from reedsolo import RSCodec
import zlib

# Module-level Reed-Solomon codec shared by text_to_bytearray (encode) and
# bytearray_to_text (decode).
# NOTE(review): 128 is RSCodec's first positional argument — presumably the
# number of error-correction symbols (`nsym`); confirm against the installed
# reedsolo version.
rs = RSCodec(128)

def initialize_weights(net_l, scale=1):
    """Initialise Conv2d / Linear / BatchNorm2d layers of one or more modules.

    Args:
        net_l: A module or a list of modules; every sub-module is visited.
        scale: Multiplier applied to Conv2d/Linear weights after Kaiming-normal
            initialisation (0 zeroes them out).

    Conv2d and Linear biases (when present) are zeroed; BatchNorm2d weights are
    set to 1 and biases to 0.
    """
    nets = net_l if isinstance(net_l, list) else [net_l]
    for net in nets:
        for layer in net.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(layer.weight, a=0, mode='fan_in')
                layer.weight.data *= scale  # for residual block
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, nn.BatchNorm2d):
                init.constant_(layer.weight, 1)
                init.constant_(layer.bias.data, 0.0)

class IWT(nn.Module):
    """Inverse of the 2x2 wavelet-style transform implemented by ``DWT``.

    Takes a tensor of shape ``(B, 4*C, H, W)`` whose channel axis holds four
    consecutive groups of ``C`` sub-band channels and returns a float32 tensor
    of shape ``(B, C, 2*H, 2*W)``.
    """

    def __init__(self):
        super(IWT, self).__init__()
        self.requires_grad = False

    def forward(self, x):
        """Recombine the four sub-band groups of ``x`` into a full-resolution tensor."""
        r = 2
        in_batch, in_channel, in_height, in_width = x.size()
        out_channel = in_channel // (r ** 2)
        out_height, out_width = r * in_height, r * in_width

        x1 = x[:, 0:out_channel, :, :] / 2
        x2 = x[:, out_channel:out_channel * 2, :, :] / 2
        x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
        x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2

        # Allocate on the input's device instead of hard-coding CUDA, so the
        # transform also works on CPU-only hosts.
        h = torch.zeros(
            (in_batch, out_channel, out_height, out_width),
            dtype=torch.float32,
            device=x.device,
        )

        # Each output 2x2 cell is a signed combination of the four sub-bands.
        h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
        h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
        h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
        h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4

        return h
class DWT(nn.Module):
    """2x2 wavelet-style transform: ``(B, C, H, W)`` -> ``(B, 4*C, H/2, W/2)``.

    The four pixels of every 2x2 cell are halved and combined into four
    sub-bands (LL, HL, LH, HH) that are concatenated along the channel axis.
    """

    def __init__(self):
        super().__init__()
        self.requires_grad = False

    def forward(self, x):
        """Return ``cat(LL, HL, LH, HH)`` along dim 1 for input ``x``."""
        even_rows = x[:, :, 0::2, :] / 2
        odd_rows = x[:, :, 1::2, :] / 2

        top_left = even_rows[:, :, :, 0::2]
        bottom_left = odd_rows[:, :, :, 0::2]
        top_right = even_rows[:, :, :, 1::2]
        bottom_right = odd_rows[:, :, :, 1::2]

        low_low = top_left + bottom_left + top_right + bottom_right
        high_low = -top_left - bottom_left + top_right + bottom_right
        low_high = -top_left + bottom_left - top_right + bottom_right
        high_high = top_left - bottom_left - top_right + bottom_right

        return torch.cat((low_low, high_low, low_high, high_high), 1)

def random_data(cover,device):
    """Return a tensor shaped like ``cover`` on ``device`` filled with random 0/1 values."""
    bits = torch.zeros(cover.size(), device=device)
    # random_(0, 2) draws integers from [0, 2), i.e. 0 or 1, in place.
    bits.random_(0, 2)
    return bits

def auxiliary_variable(shape, device=None):
    """Return a standard-normal noise tensor of the given shape.

    Args:
        shape: Target tensor shape; the first dimension is filled one slice
            at a time.
        device: Device for the result. Defaults to CUDA when available and
            CPU otherwise (the function previously hard-coded CUDA and failed
            on CPU-only hosts).

    Returns:
        A float tensor of ``shape`` on ``device``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    noise = torch.zeros(shape, device=device)
    for i in range(noise.shape[0]):
        # Samples are drawn on the CPU and then moved, as before.
        noise[i] = torch.randn(noise[i].shape).to(device)

    return noise

def computePSNR(origin,pred):
    """Return the PSNR between two array-likes for a peak value of 255.

    Both inputs are converted to float32 arrays. When the mean squared error
    is below 1e-10 the function returns 100 instead of evaluating the log.
    """
    reference = np.array(origin).astype(np.float32)
    estimate = np.array(pred).astype(np.float32)
    squared_error = (reference / 1.0 - estimate / 1.0) ** 2
    mse = np.mean(squared_error)
    return 100 if mse < 1.0e-10 else 10 * math.log10(255.0 ** 2 / mse)

def make_payload(width, height, depth, text, batch = 1):
    """Tile the encoded bits of ``text`` into a ``(batch, depth, height, width)`` tensor.

    The message is ``text_to_bits(text)`` followed by 32 zero bits; it is
    repeated until ``batch * width * height * depth`` values are available and
    then truncated to exactly that length.
    """
    message = text_to_bits(text) + [0] * 32
    total = batch * width * height * depth

    # Ceiling division gives the number of copies needed to cover `total`;
    # at least one copy is always taken before truncating.
    copies = max(-(-total // len(message)), 1)
    payload = (message * copies)[:total]

    return torch.FloatTensor(payload).view(batch, depth, height, width)

def text_to_bits(text):
    """Encode ``text`` with ``text_to_bytearray`` and expand the bytes into a list of bits."""
    encoded = text_to_bytearray(text)
    return bytearray_to_bits(encoded)

def bytearray_to_bits(x):
    """Expand each byte value in ``x`` into 8 bits, most significant bit first.

    Returns a flat list of ints (0 or 1).
    """
    bits = []
    for value in x:
        # '08b' left-pads the binary representation with zeros to 8 digits.
        bits.extend(int(digit) for digit in format(value, '08b'))
    return bits

def text_to_bytearray(text):
    """UTF-8 encode ``text``, zlib-compress it, then encode it with the module-level ``rs`` codec.

    Raises:
        AssertionError: if ``text`` is not a ``str``.
    """
    assert isinstance(text, str), "expected a string"
    compressed = zlib.compress(text.encode("utf-8"))
    return rs.encode(bytearray(compressed))

def bits_to_bytearray(bits):
    """Pack a flat sequence of 0/1 (or bool) values into bytes, MSB first.

    Args:
        bits: Flat sequence of bits; values are expected to be 0/1 or booleans.

    Returns:
        A ``bytearray`` with one byte per complete group of 8 bits. Trailing
        bits that do not fill a whole byte are dropped.
    """
    flat = np.asarray(bits)
    n_bytes = flat.size // 8
    if n_bytes == 0:
        return bytearray()
    # np.packbits packs big-endian by default, matching int('b7...b0', 2),
    # and avoids building one Python string per byte for large payloads.
    packed = np.packbits(flat[:n_bytes * 8].astype(np.uint8))
    return bytearray(packed.tobytes())

def bytearray_to_text(x):
    """Best-effort decode of a candidate byte string back to text.

    Reverses ``text_to_bytearray``: decode with the module-level ``rs`` codec,
    zlib-decompress the first element of the result, then UTF-8 decode.

    Returns:
        The decoded string, or ``False`` when any decoding step fails.
    """
    try:
        decoded = rs.decode(x)
        # NOTE(review): indexing [0] assumes rs.decode returns a tuple whose
        # first item is the message — confirm for the installed reedsolo.
        raw = zlib.decompress(decoded[0])
        return raw.decode("utf-8")
    except Exception:
        # Any decoding error just means this candidate is not a valid message.
        # `Exception` (not `BaseException`) lets KeyboardInterrupt and
        # SystemExit propagate.
        return False

d3net.py:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from model import *
from block import INV_block, INV_block_reverse # Import INV_block_reverse

class D3net(nn.Module):
    """Chain of eight ``INV_block`` modules applied in order inv1 .. inv8."""

    def __init__(self):
        super().__init__()
        # Blocks are registered as attributes inv1..inv8, in that order.
        for index in range(1, 9):
            setattr(self, f'inv{index}', INV_block())

    def forward(self, x):
        """Pass ``x`` through inv1, inv2, ..., inv8 and return the result."""
        out = x
        for index in range(1, 9):
            stage = getattr(self, f'inv{index}')
            out = stage(out)
        return out

# Added for inverse operation
class D3net_reverse(nn.Module):
    """Inverse of ``D3net``: reversed blocks applied in order inv8 .. inv1.

    Each stage is an ``INV_block_reverse`` wrapping the corresponding block of
    the given ``D3net`` instance.
    """

    def __init__(self, original_d3net_instance):
        super().__init__()
        # Walk the forward blocks from last (inv8) to first (inv1).
        reversed_names = [f'inv{number}' for number in range(8, 0, -1)]
        self.inv_blocks_rev = nn.ModuleList([
            INV_block_reverse(getattr(original_d3net_instance, name))
            for name in reversed_names
        ])

    def forward(self, y_cat):
        """Map a forward-network output ``y_cat`` back to the forward-network input."""
        recovered = y_cat
        for stage in self.inv_blocks_rev:
            recovered = stage(recovered)
        return recovered

model.py:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch.nn as nn
import torch
from d3net import D3net


class Model(nn.Module):
    """Thin wrapper holding a ``D3net`` under the attribute ``model``.

    Args:
        cuda: When true (default) the network is moved to CUDA if a CUDA
            device is available. If none is available a warning is emitted
            and the network stays on the CPU instead of raising.
    """

    def __init__(self,cuda=True):
        super(Model, self).__init__()
        self.model = D3net()
        if cuda:
            if torch.cuda.is_available():
                self.model.cuda()
            else:
                import warnings
                warnings.warn(
                    "CUDA requested but not available; keeping the model on CPU.",
                    RuntimeWarning,
                )
        # init_model(self) is intentionally not called here, so the
        # constructor does not apply init_model's random re-initialisation.

    def forward(self, x):
        """Delegate to the wrapped ``D3net``."""
        out = self.model(x)
        return out


def init_model(mod):
    """Randomly re-initialise every trainable parameter of ``mod`` in place.

    Trainable parameters are replaced with ``0.01 * N(0, 1)`` samples; those
    belonging to a sub-module named ``conv5`` are then zeroed. Parameters with
    ``requires_grad=False`` are left untouched.

    The new values are placed on each parameter's current device (previously
    they were always forced onto CUDA, which failed on CPU-only hosts).
    """
    for key, param in mod.named_parameters():
        if not param.requires_grad:
            continue
        param.data = 0.01 * torch.randn(param.data.shape).to(param.device)
        parts = key.split('.')
        # Top-level parameters have no parent component; guard the [-2] lookup.
        if len(parts) >= 2 and parts[-2] == 'conv5':
            param.data.fill_(0.)

test.py:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
from model import Model
from utils import DWT, IWT, make_payload, auxiliary_variable, bits_to_bytearray, bytearray_to_text
import torchvision
from collections import Counter
from PIL import Image
import torchvision.transforms as T

# Import the reverse D3net
from d3net import D3net_reverse

# Preprocessing applied to every image before it enters the network:
# centre-crop to 720x1280 (height, width), then convert to a tensor.
transform_test = T.Compose([T.CenterCrop((720, 1280)), T.ToTensor()])

def load(name):
    """Load the checkpoint at ``name`` into the module-level ``d3net``.

    The checkpoint is expected to hold the weights under the key ``'net'``;
    entries whose name contains ``'tmp_var'`` are skipped.
    """
    # map_location='cpu' lets checkpoints saved from CUDA tensors load on
    # CPU-only hosts; load_state_dict then copies the values into the model's
    # parameters on whatever device they live on.
    # NOTE(security): torch.load unpickles the file — only load checkpoints
    # from a trusted source (consider weights_only=True where supported).
    state_dicts = torch.load(name, map_location='cpu')
    network_state_dict = {
        key: value
        for key, value in state_dicts['net'].items()
        if 'tmp_var' not in key
    }
    d3net.load_state_dict(network_state_dict)

def transform2tensor(img):
    """Open the image at ``img`` and return it as a batched RGB tensor on ``device``.

    The image is converted to RGB, passed through ``transform_test`` and given
    a leading batch dimension of size 1.
    """
    # The context manager closes the underlying file; convert() returns an
    # independent in-memory image, so it stays usable after the file is closed.
    with Image.open(img) as source:
        rgb = source.convert('RGB')
    return transform_test(rgb).unsqueeze(0).to(device)

def encode(cover, text):
    """Hide ``text`` inside the image at path ``cover`` and write ``./steg.png``.

    The cover image and the tiled bit payload are each transformed with the
    module-level ``dwt``, concatenated along the channel axis and passed
    through ``d3net``. The first 12 output channels are transformed back with
    ``iwt`` and saved as the stego image.
    """
    cover = transform2tensor(cover)
    B, C, H, W = cover.size()
    payload = make_payload(W, H, C, text, B).to(device)

    # Pure inference: skip autograd so no intermediate activations are kept.
    with torch.no_grad():
        input_img = torch.cat([dwt(cover), dwt(payload)], dim=1)
        output = d3net(input_img)
        # Channels 0..11 hold the stego sub-bands; the rest are discarded here.
        output_steg = output.narrow(1, 0, 4 * 3)
        output_img = iwt(output_steg)

    torchvision.utils.save_image(output_img, './steg.png')


def decode(steg_path):
    """Recover and print the text hidden in the stego image at ``steg_path``.

    Steps:
      1. Transform the stego image with the module-level ``dwt`` (12 channels).
      2. Append an all-zero tensor of the same shape as the second half of the
         24-channel input to the inverse network.
      3. Run ``D3net_reverse`` built on ``d3net.model`` (shares its weights).
      4. Take channels 12..23 as the recovered payload and apply ``iwt``.
      5. Threshold at 0 to get bits, pack them into bytes, split on four zero
         bytes and keep the candidate that decodes successfully most often.

    Raises:
        ValueError: if no candidate decodes to a valid message.
    """
    # Pure inference: skip autograd so activations of all blocks are not kept.
    with torch.no_grad():
        stego_dwt = dwt(transform2tensor(steg_path))

        # The forward network emits (stego_dwt, z); z is not stored in the
        # image, so an all-zero tensor of matching shape stands in for it.
        z_prior = torch.zeros_like(stego_dwt)
        input_to_reverse = torch.cat((stego_dwt, z_prior), 1)

        # `d3net` is a Model; `d3net.model` is the D3net holding the weights.
        reverse_net = D3net_reverse(d3net.model)
        reverse_net.eval()
        reverse_net.to(device)

        recovered_channels = reverse_net(input_to_reverse)

        # The forward input was cat(cover_dwt, payload_dwt), 12 channels each,
        # so the payload sits in channels 12..23.
        secret_dwt = recovered_channels.narrow(1, 4 * 3, 4 * 3)
        secret_rev = iwt(secret_dwt)

    # Positive values are read as bit 1, everything else as bit 0.
    bits = (secret_rev.reshape(-1) > 0).int().cpu().numpy().tolist()

    # make_payload appends 32 zero bits after each encoded copy of the message,
    # i.e. four 0x00 bytes between repetitions; split on that delimiter.
    candidates = Counter()
    for chunk in bits_to_bytearray(bits).split(b'\x00\x00\x00\x00'):
        decoded = bytearray_to_text(bytearray(chunk))
        if decoded:
            candidates[decoded] += 1

    if not candidates:
        raise ValueError('Failed to find message.')

    # Majority vote across the repeated copies.
    candidate, _ = candidates.most_common(1)[0]
    print(candidate)


if __name__ == '__main__':
    # Choose the device first so the model is only asked to use CUDA when a
    # CUDA device exists. These names stay module-level globals because
    # load/encode/decode/transform2tensor read them.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    d3net = Model(cuda=device.type == 'cuda')
    load('magic.potions')
    d3net.eval()

    dwt = DWT()
    iwt = IWT()

    text = r'd3ctf{Getting that model to converge felt like pure sorcery}'
    steg = r'./steg.png'
    cover = './poster.png'
    # encode(cover, text)  # Uncomment to re-create ./steg.png from the cover image.
    decode(steg)  # Recover and print the hidden message from the stego image.