1 : // Copyright 2012 Google Inc. All Rights Reserved.
2 : //
3 : // Licensed under the Apache License, Version 2.0 (the "License");
4 : // you may not use this file except in compliance with the License.
5 : // You may obtain a copy of the License at
6 : //
7 : // http://www.apache.org/licenses/LICENSE-2.0
8 : //
9 : // Unless required by applicable law or agreed to in writing, software
10 : // distributed under the License is distributed on an "AS IS" BASIS,
11 : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : // See the License for the specific language governing permissions and
13 : // limitations under the License.
14 :
15 : #include "syzygy/pe/transforms/add_imports_transform.h"
16 :
17 : #include "gtest/gtest.h"
18 : #include "syzygy/core/unittest_util.h"
19 : #include "syzygy/pe/decomposer.h"
20 : #include "syzygy/pe/pe_utils.h"
21 : #include "syzygy/pe/unittest_util.h"
22 :
23 : namespace pe {
24 : namespace transforms {
25 :
26 : using block_graph::BlockGraph;
27 : using core::RelativeAddress;
28 : typedef AddImportsTransform::ImportedModule ImportedModule;
29 :
30 : namespace {
31 :
32 : class AddImportsTransformTest : public testing::PELibUnitTest {
33 : public:
34 E : AddImportsTransformTest() : image_layout_(&block_graph_) {
35 E : }
36 :
37 E : virtual void SetUp() {
38 E : FilePath image_path(testing::GetExeRelativePath(testing::kTestDllName));
39 :
40 E : ASSERT_TRUE(pe_file_.Init(image_path));
41 :
42 : // Decompose the test image and look at the result.
43 E : Decomposer decomposer(pe_file_);
44 E : ASSERT_TRUE(decomposer.Decompose(&image_layout_));
45 :
46 : // Retrieve and validate the DOS header.
47 : dos_header_block_ =
48 E : image_layout_.blocks.GetBlockByAddress(RelativeAddress(0));
49 E : ASSERT_TRUE(dos_header_block_ != NULL);
50 E : ASSERT_TRUE(IsValidDosHeaderBlock(dos_header_block_));
51 E : }
52 :
53 : PEFile pe_file_;
54 : BlockGraph block_graph_;
55 : ImageLayout image_layout_;
56 : BlockGraph::Block* dos_header_block_;
57 : };
58 :
59 : // Given an ImportedModule tests that all of its symbols have been properly
60 : // processed.
61 E : void TestSymbols(const ImportedModule& module) {
62 E : for (size_t i = 0; i < module.size(); ++i) {
63 E : BlockGraph::Reference ref;
64 E : EXPECT_TRUE(module.GetSymbolReference(i, &ref));
65 E : EXPECT_TRUE(ref.referenced() != NULL);
66 E : EXPECT_GE(ref.offset(), 0);
67 : EXPECT_LT(ref.offset(),
68 E : static_cast<BlockGraph::Offset>(ref.referenced()->size()));
69 E : }
70 E : }
71 :
72 : } // namespace
73 :
74 E : TEST_F(AddImportsTransformTest, AddImportsExisting) {
75 E : ImportedModule module("export_dll.dll");
76 : size_t function1 = module.AddSymbol("function1",
77 E : ImportedModule::kAlwaysImport);
78 : size_t function3 = module.AddSymbol("function3",
79 E : ImportedModule::kAlwaysImport);
80 E : EXPECT_EQ("function1", module.GetSymbolName(function1));
81 E : EXPECT_EQ("function3", module.GetSymbolName(function3));
82 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.mode());
83 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.GetSymbolMode(function1));
84 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.GetSymbolMode(function3));
85 :
86 E : AddImportsTransform transform;
87 E : transform.AddModule(&module);
88 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
89 E : &transform, &block_graph_, dos_header_block_));
90 E : EXPECT_EQ(0u, transform.modules_added());
91 E : EXPECT_EQ(0u, transform.symbols_added());
92 :
93 E : EXPECT_TRUE(module.ModuleIsImported());
94 E : EXPECT_TRUE(module.SymbolIsImported(function1));
95 E : EXPECT_TRUE(module.SymbolIsImported(function3));
96 :
97 E : EXPECT_FALSE(module.ModuleWasAdded());
98 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
99 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
100 :
101 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
102 E : module.GetSymbolIatIndex(function1));
103 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
104 E : module.GetSymbolIatIndex(function3));
105 :
106 E : EXPECT_NO_FATAL_FAILURE(TestSymbols(module));
107 E : }
108 :
109 E : TEST_F(AddImportsTransformTest, AddImportsNewSymbol) {
110 E : ImportedModule module("export_dll.dll");
111 : size_t function1 = module.AddSymbol("function1",
112 E : ImportedModule::kAlwaysImport);
113 : size_t function3 = module.AddSymbol("function3",
114 E : ImportedModule::kAlwaysImport);
115 : size_t function4 = module.AddSymbol("function4",
116 E : ImportedModule::kAlwaysImport);
117 E : EXPECT_EQ("function1", module.GetSymbolName(function1));
118 E : EXPECT_EQ("function3", module.GetSymbolName(function3));
119 E : EXPECT_EQ("function4", module.GetSymbolName(function4));
120 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.mode());
121 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.GetSymbolMode(function1));
122 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.GetSymbolMode(function3));
123 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.GetSymbolMode(function4));
124 :
125 E : EXPECT_TRUE(module.import_descriptor().block() == NULL);
126 :
127 E : AddImportsTransform transform;
128 E : transform.AddModule(&module);
129 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
130 E : &transform, &block_graph_, dos_header_block_));
131 E : EXPECT_EQ(0u, transform.modules_added());
132 E : EXPECT_EQ(1u, transform.symbols_added());
133 :
134 E : EXPECT_TRUE(module.import_descriptor().block() != NULL);
135 :
136 E : EXPECT_TRUE(module.ModuleIsImported());
137 E : EXPECT_TRUE(module.SymbolIsImported(function1));
138 E : EXPECT_TRUE(module.SymbolIsImported(function3));
139 E : EXPECT_TRUE(module.SymbolIsImported(function4));
140 :
141 E : EXPECT_FALSE(module.ModuleWasAdded());
142 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
143 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
144 E : EXPECT_TRUE(module.SymbolWasAdded(function4));
145 :
146 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
147 E : module.GetSymbolIatIndex(function1));
148 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
149 E : module.GetSymbolIatIndex(function3));
150 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
151 E : module.GetSymbolIatIndex(function4));
152 :
153 E : EXPECT_NO_FATAL_FAILURE(TestSymbols(module));
154 :
155 : // TODO(chrisha): Write the image and try to load it!
156 E : }
157 :
158 E : TEST_F(AddImportsTransformTest, AddImportsNewModule) {
159 E : ImportedModule module("call_trace_client_rpc.dll");
160 : size_t indirect_penter = module.AddSymbol(
161 E : "_indirect_penter", ImportedModule::kAlwaysImport);
162 : size_t indirect_penter_dllmain = module.AddSymbol(
163 E : "_indirect_penter_dllmain", ImportedModule::kAlwaysImport);
164 : EXPECT_EQ("_indirect_penter",
165 E : module.GetSymbolName(indirect_penter));
166 : EXPECT_EQ("_indirect_penter_dllmain",
167 E : module.GetSymbolName(indirect_penter_dllmain));
168 E : EXPECT_EQ(ImportedModule::kAlwaysImport, module.mode());
169 : EXPECT_EQ(ImportedModule::kAlwaysImport,
170 E : module.GetSymbolMode(indirect_penter));
171 : EXPECT_EQ(ImportedModule::kAlwaysImport,
172 E : module.GetSymbolMode(indirect_penter_dllmain));
173 :
174 E : AddImportsTransform transform;
175 E : transform.AddModule(&module);
176 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
177 E : &transform, &block_graph_, dos_header_block_));
178 E : EXPECT_EQ(1u, transform.modules_added());
179 E : EXPECT_EQ(2u, transform.symbols_added());
180 :
181 E : EXPECT_TRUE(module.ModuleIsImported());
182 E : EXPECT_TRUE(module.SymbolIsImported(indirect_penter));
183 E : EXPECT_TRUE(module.SymbolIsImported(indirect_penter_dllmain));
184 :
185 E : EXPECT_TRUE(module.ModuleWasAdded());
186 E : EXPECT_TRUE(module.SymbolWasAdded(indirect_penter));
187 E : EXPECT_TRUE(module.SymbolWasAdded(indirect_penter_dllmain));
188 :
189 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
190 E : module.GetSymbolIatIndex(indirect_penter));
191 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
192 E : module.GetSymbolIatIndex(indirect_penter_dllmain));
193 :
194 E : EXPECT_NO_FATAL_FAILURE(TestSymbols(module));
195 :
196 : // TODO(chrisha): Write the image and try to load it!
197 E : }
198 :
199 E : TEST_F(AddImportsTransformTest, FindImportsExisting) {
200 E : ImportedModule module("export_dll.dll");
201 E : size_t function1 = module.AddSymbol("function1", ImportedModule::kFindOnly);
202 E : size_t function3 = module.AddSymbol("function3", ImportedModule::kFindOnly);
203 E : EXPECT_EQ("function1", module.GetSymbolName(function1));
204 E : EXPECT_EQ("function3", module.GetSymbolName(function3));
205 E : EXPECT_EQ(ImportedModule::kFindOnly, module.mode());
206 E : EXPECT_EQ(ImportedModule::kFindOnly, module.GetSymbolMode(function1));
207 E : EXPECT_EQ(ImportedModule::kFindOnly, module.GetSymbolMode(function3));
208 :
209 E : AddImportsTransform transform;
210 E : transform.AddModule(&module);
211 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
212 E : &transform, &block_graph_, dos_header_block_));
213 E : EXPECT_EQ(0u, transform.modules_added());
214 E : EXPECT_EQ(0u, transform.symbols_added());
215 :
216 E : EXPECT_TRUE(module.ModuleIsImported());
217 E : EXPECT_TRUE(module.SymbolIsImported(function1));
218 E : EXPECT_TRUE(module.SymbolIsImported(function3));
219 :
220 E : EXPECT_FALSE(module.ModuleWasAdded());
221 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
222 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
223 :
224 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
225 E : module.GetSymbolIatIndex(function1));
226 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
227 E : module.GetSymbolIatIndex(function3));
228 E : }
229 :
230 E : TEST_F(AddImportsTransformTest, FindImportsNewSymbol) {
231 E : ImportedModule module("export_dll.dll");
232 E : size_t function1 = module.AddSymbol("function1", ImportedModule::kFindOnly);
233 E : size_t function3 = module.AddSymbol("function3", ImportedModule::kFindOnly);
234 E : size_t function4 = module.AddSymbol("function4", ImportedModule::kFindOnly);
235 E : EXPECT_EQ("function1", module.GetSymbolName(function1));
236 E : EXPECT_EQ("function3", module.GetSymbolName(function3));
237 E : EXPECT_EQ("function4", module.GetSymbolName(function4));
238 E : EXPECT_EQ(ImportedModule::kFindOnly, module.mode());
239 E : EXPECT_EQ(ImportedModule::kFindOnly, module.GetSymbolMode(function1));
240 E : EXPECT_EQ(ImportedModule::kFindOnly, module.GetSymbolMode(function3));
241 E : EXPECT_EQ(ImportedModule::kFindOnly, module.GetSymbolMode(function4));
242 :
243 E : AddImportsTransform transform;
244 E : transform.AddModule(&module);
245 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
246 E : &transform, &block_graph_, dos_header_block_));
247 E : EXPECT_EQ(0u, transform.modules_added());
248 E : EXPECT_EQ(0u, transform.symbols_added());
249 :
250 E : EXPECT_TRUE(module.ModuleIsImported());
251 E : EXPECT_TRUE(module.SymbolIsImported(function1));
252 E : EXPECT_TRUE(module.SymbolIsImported(function3));
253 E : EXPECT_FALSE(module.SymbolIsImported(function4));
254 :
255 E : EXPECT_FALSE(module.ModuleWasAdded());
256 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
257 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
258 E : EXPECT_FALSE(module.SymbolWasAdded(function4));
259 :
260 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
261 E : module.GetSymbolIatIndex(function1));
262 : EXPECT_NE(ImportedModule::kInvalidIatIndex,
263 E : module.GetSymbolIatIndex(function3));
264 : EXPECT_EQ(ImportedModule::kInvalidIatIndex,
265 E : module.GetSymbolIatIndex(function4));
266 E : }
267 :
268 E : TEST_F(AddImportsTransformTest, FindImportsNewModule) {
269 E : ImportedModule module("call_trace_client_rpc.dll");
270 : size_t indirect_penter = module.AddSymbol(
271 E : "_indirect_penter", ImportedModule::kFindOnly);
272 : size_t indirect_penter_dllmain = module.AddSymbol(
273 E : "_indirect_penter_dllmain", ImportedModule::kFindOnly);
274 : EXPECT_EQ("_indirect_penter",
275 E : module.GetSymbolName(indirect_penter));
276 : EXPECT_EQ("_indirect_penter_dllmain",
277 E : module.GetSymbolName(indirect_penter_dllmain));
278 E : EXPECT_EQ(ImportedModule::kFindOnly, module.mode());
279 E : EXPECT_EQ(ImportedModule::kFindOnly, module.GetSymbolMode(indirect_penter));
280 : EXPECT_EQ(ImportedModule::kFindOnly,
281 E : module.GetSymbolMode(indirect_penter_dllmain));
282 :
283 E : AddImportsTransform transform;
284 E : transform.AddModule(&module);
285 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
286 E : &transform, &block_graph_, dos_header_block_));
287 E : EXPECT_EQ(0u, transform.modules_added());
288 E : EXPECT_EQ(0u, transform.symbols_added());
289 :
290 E : EXPECT_FALSE(module.ModuleIsImported());
291 E : EXPECT_FALSE(module.SymbolIsImported(indirect_penter));
292 E : EXPECT_FALSE(module.SymbolIsImported(indirect_penter_dllmain));
293 :
294 E : EXPECT_FALSE(module.ModuleWasAdded());
295 E : EXPECT_FALSE(module.SymbolWasAdded(indirect_penter));
296 E : EXPECT_FALSE(module.SymbolWasAdded(indirect_penter_dllmain));
297 :
298 : EXPECT_EQ(ImportedModule::kInvalidIatIndex,
299 E : module.GetSymbolIatIndex(indirect_penter));
300 : EXPECT_EQ(ImportedModule::kInvalidIatIndex,
301 E : module.GetSymbolIatIndex(indirect_penter_dllmain));
302 E : }
303 :
304 : } // namespace transforms
305 : } // namespace pe
|