1 : // Copyright 2013 Google Inc. All Rights Reserved.
2 : //
3 : // Licensed under the Apache License, Version 2.0 (the "License");
4 : // you may not use this file except in compliance with the License.
5 : // You may obtain a copy of the License at
6 : //
7 : // http://www.apache.org/licenses/LICENSE-2.0
8 : //
9 : // Unless required by applicable law or agreed to in writing, software
10 : // distributed under the License is distributed on an "AS IS" BASIS,
11 : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : // See the License for the specific language governing permissions and
13 : // limitations under the License.
14 :
15 : #include "syzygy/pe/transforms/coff_add_imports_transform.h"
16 :
17 : #include "gtest/gtest.h"
18 : #include "syzygy/block_graph/unittest_util.h"
19 : #include "syzygy/pe/pe_utils.h"
20 : #include "syzygy/pe/unittest_util.h"
21 :
22 : namespace pe {
23 : namespace transforms {
24 :
25 : using block_graph::BlockGraph;
26 :
27 : namespace {
28 :
29 : class CoffAddImportsTransformTest : public testing::CoffUnitTest {
30 : public:
31 E : virtual void SetUp() override { testing::CoffUnitTest::SetUp(); }
32 :
33 : // Check that symbols in @p module have been assigned a reference, and that
34 : // they pass through a round-trip writing and decomposition.
35 E : void TestSymbols(const ImportedModule& module) {
36 : // Check resulting references.
37 E : for (size_t i = 0; i < module.size(); ++i) {
38 E : BlockGraph::Reference ref;
39 E : EXPECT_TRUE(module.GetSymbolReference(i, &ref));
40 E : EXPECT_TRUE(ref.referenced() != NULL);
41 E : EXPECT_GE(ref.offset(), 0);
42 E : EXPECT_LT(ref.offset(),
43 : static_cast<BlockGraph::Offset>(ref.referenced()->size()));
44 E : }
45 :
46 E : ASSERT_NO_FATAL_FAILURE(TestRoundTrip());
47 E : }
48 : };
49 :
// Names of the symbols requested from the imported module by the tests below.
const char kFunction1Name[] = "__imp_?function1@@YAHXZ";
const char kFunction3Name[] = "?function3@@YAHXZ";
const char kFunction4Name[] = "?function4@@YAHXZ";
// Named after the symbol it actually holds (it was previously called kMemcpy
// while containing "_memset"). Not referenced by any test in this file.
const char kMemset[] = "_memset";  // Multiply defined.
54 :
55 : } // namespace
56 :
57 E : TEST_F(CoffAddImportsTransformTest, AddImportsExisting) {
58 E : ASSERT_NO_FATAL_FAILURE(DecomposeOriginal());
59 E : ImportedModule module("export_dll.dll");
60 : size_t function1 = module.AddSymbol(kFunction1Name,
61 E : ImportedModule::kAlwaysImport);
62 : size_t function3 = module.AddSymbol(kFunction3Name,
63 E : ImportedModule::kAlwaysImport);
64 :
65 E : CoffAddImportsTransform transform;
66 E : transform.AddModule(&module);
67 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
68 E : &transform, &policy_, &block_graph_, headers_block_));
69 E : EXPECT_EQ(0u, transform.modules_added());
70 E : EXPECT_EQ(0u, transform.symbols_added());
71 :
72 E : EXPECT_TRUE(module.ModuleIsImported());
73 E : EXPECT_TRUE(module.SymbolIsImported(function1));
74 E : EXPECT_TRUE(module.SymbolIsImported(function3));
75 :
76 E : EXPECT_FALSE(module.ModuleWasAdded());
77 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
78 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
79 :
80 E : EXPECT_NO_FATAL_FAILURE(TestSymbols(module));
81 E : }
82 :
83 E : TEST_F(CoffAddImportsTransformTest, AddImportsNewSymbol) {
84 E : ASSERT_NO_FATAL_FAILURE(DecomposeOriginal());
85 E : ImportedModule module("export_dll.dll");
86 : size_t function1 = module.AddSymbol(kFunction1Name,
87 E : ImportedModule::kAlwaysImport);
88 : size_t function3 = module.AddSymbol(kFunction3Name,
89 E : ImportedModule::kAlwaysImport);
90 : size_t function4 = module.AddSymbol(kFunction4Name,
91 E : ImportedModule::kAlwaysImport);
92 :
93 E : CoffAddImportsTransform transform;
94 E : transform.AddModule(&module);
95 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
96 E : &transform, &policy_, &block_graph_, headers_block_));
97 E : EXPECT_EQ(0u, transform.modules_added());
98 E : EXPECT_EQ(1u, transform.symbols_added());
99 :
100 E : EXPECT_TRUE(module.ModuleIsImported());
101 E : EXPECT_TRUE(module.SymbolIsImported(function1));
102 E : EXPECT_TRUE(module.SymbolIsImported(function3));
103 E : EXPECT_TRUE(module.SymbolIsImported(function4));
104 :
105 E : EXPECT_FALSE(module.ModuleWasAdded());
106 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
107 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
108 E : EXPECT_TRUE(module.SymbolWasAdded(function4));
109 :
110 E : EXPECT_NO_FATAL_FAILURE(TestSymbols(module));
111 E : }
112 :
113 E : TEST_F(CoffAddImportsTransformTest, FindImportsExistingMultiple) {
114 E : ASSERT_NO_FATAL_FAILURE(DecomposeOriginal());
115 E : ImportedModule module("export_dll.dll");
116 : size_t function1 = module.AddSymbol(kFunction1Name,
117 E : ImportedModule::kFindOnly);
118 : size_t function3 = module.AddSymbol(kFunction3Name,
119 E : ImportedModule::kFindOnly);
120 :
121 E : CoffAddImportsTransform transform;
122 E : transform.AddModule(&module);
123 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
124 E : &transform, &policy_, &block_graph_, headers_block_));
125 E : EXPECT_EQ(0u, transform.modules_added());
126 E : EXPECT_EQ(0u, transform.symbols_added());
127 :
128 E : EXPECT_TRUE(module.ModuleIsImported());
129 E : EXPECT_TRUE(module.SymbolIsImported(function1));
130 E : EXPECT_TRUE(module.SymbolIsImported(function3));
131 :
132 E : EXPECT_FALSE(module.ModuleWasAdded());
133 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
134 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
135 E : }
136 :
137 E : TEST_F(CoffAddImportsTransformTest, FindImportsNewSymbol) {
138 E : ASSERT_NO_FATAL_FAILURE(DecomposeOriginal());
139 E : ImportedModule module("export_dll.dll");
140 : size_t function1 = module.AddSymbol(kFunction1Name,
141 E : ImportedModule::kFindOnly);
142 : size_t function3 = module.AddSymbol(kFunction3Name,
143 E : ImportedModule::kFindOnly);
144 : size_t function4 = module.AddSymbol(kFunction4Name,
145 E : ImportedModule::kFindOnly);
146 :
147 E : CoffAddImportsTransform transform;
148 E : transform.AddModule(&module);
149 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
150 E : &transform, &policy_, &block_graph_, headers_block_));
151 E : EXPECT_EQ(0u, transform.modules_added());
152 E : EXPECT_EQ(0u, transform.symbols_added());
153 :
154 E : EXPECT_TRUE(module.ModuleIsImported());
155 E : EXPECT_TRUE(module.SymbolIsImported(function1));
156 E : EXPECT_TRUE(module.SymbolIsImported(function3));
157 E : EXPECT_FALSE(module.SymbolIsImported(function4));
158 :
159 E : EXPECT_FALSE(module.ModuleWasAdded());
160 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
161 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
162 E : EXPECT_FALSE(module.SymbolWasAdded(function4));
163 E : }
164 :
165 E : TEST_F(CoffAddImportsTransformTest, EmptyStringTable) {
166 : // Override with a different module.
167 : test_dll_obj_path_ = testing::GetSrcRelativePath(
168 E : testing::kEmptyStringTableCoffName);
169 E : ASSERT_NO_FATAL_FAILURE(DecomposeOriginal());
170 E : ImportedModule module("export_dll.dll");
171 : size_t function1 = module.AddSymbol(kFunction1Name,
172 E : ImportedModule::kAlwaysImport);
173 :
174 E : CoffAddImportsTransform transform;
175 E : transform.AddModule(&module);
176 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
177 E : &transform, &policy_, &block_graph_, headers_block_));
178 E : EXPECT_EQ(0u, transform.modules_added());
179 E : EXPECT_EQ(1u, transform.symbols_added());
180 :
181 E : EXPECT_TRUE(module.ModuleIsImported());
182 E : EXPECT_TRUE(module.SymbolIsImported(function1));
183 :
184 E : EXPECT_FALSE(module.ModuleWasAdded());
185 E : EXPECT_TRUE(module.SymbolWasAdded(function1));
186 E : }
187 :
188 :
189 : } // namespace transforms
190 : } // namespace pe
|